Skip to content

Commit aca5395

Browse files
committed
feat: any and == working with n-arity nodes
1 parent cd27db1 commit aca5395

File tree

5 files changed

+697
-57
lines changed

5 files changed

+697
-57
lines changed

src/Node.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,11 @@ macro make_accessors(node_type)
216216
end)
217217
end
218218

219-
@make_accessors Node{T,2} where {T}
220-
@make_accessors GraphNode{T,2} where {T}
219+
# @make_accessors Node{T,2} where {T}
220+
# @make_accessors GraphNode{T,2} where {T}
221+
@make_accessors Node
222+
@make_accessors GraphNode
223+
# TODO: Disable the `.l` accessors eventually, once the codebase is fully generic
221224

222225
@inline children(node::AbstractNode) = node.children
223226
@inline function children(node::AbstractNode, ::Val{n}) where {n}

src/base.jl

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,19 @@ end
174174
Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
175175
By using this instead of tree_mapreduce, we can take advantage of early exits.
176176
"""
177-
function any(f::F, tree::AbstractNode) where {F<:Function}
178-
if tree.degree == 0
179-
return @inline(f(tree))::Bool
180-
elseif tree.degree == 1
181-
return @inline(f(tree))::Bool || any(f, tree.l)
182-
else
183-
return @inline(f(tree))::Bool || any(f, tree.l) || any(f, tree.r)
177+
@generated function any(f::F, tree::AbstractNode{D}) where {F<:Function,D}
178+
quote
179+
deg = tree.degree
180+
181+
deg == 0 && return @inline(f(tree))
182+
183+
return (
184+
@inline(f(tree)) || Base.Cartesian.@nif(
185+
$D, i -> deg == i, i -> let cs = children(tree, Val(i))
186+
Base.Cartesian.@nany(i, j -> any(f, cs[j]))
187+
end
188+
)
189+
)
184190
end
185191
end
186192

@@ -189,49 +195,49 @@ function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode)
189195
end
190196
function Base.:(==)(a::N, b::N)::Bool where {N<:AbstractExpressionNode}
191197
if preserve_sharing(N)
192-
return inner_is_equal_shared(a, b, Dict{UInt,Nothing}(), Dict{UInt,Nothing}())
198+
return inner_is_equal(a, b, (; a=Dict{UInt,Nothing}(), b=Dict{UInt,Nothing}()))
193199
else
194-
return inner_is_equal(a, b)
200+
return inner_is_equal(a, b, nothing)
195201
end
196202
end
197-
function inner_is_equal(a, b)
198-
(degree = a.degree) != b.degree && return false
199-
if degree == 0
200-
return leaf_equal(a, b)
201-
elseif degree == 1
202-
return branch_equal(a, b) && inner_is_equal(a.l, b.l)
203-
else
204-
return branch_equal(a, b) && inner_is_equal(a.l, b.l) && inner_is_equal(a.r, b.r)
205-
end
206-
end
207-
function inner_is_equal_shared(a, b, id_map_a, id_map_b)
208-
id_a = objectid(a)
209-
id_b = objectid(b)
210-
has_a = haskey(id_map_a, id_a)
211-
has_b = haskey(id_map_b, id_b)
212-
213-
if has_a && has_b
214-
return true
215-
elseif has_a has_b
216-
return false
217-
end
218-
219-
(degree = a.degree) != b.degree && return false
203+
@generated function inner_is_equal(
204+
a::AbstractNode{D}, b::AbstractNode{D}, id_maps::Union{Nothing,NamedTuple}
205+
) where {D}
206+
quote
207+
ids = !isnothing(id_maps) ? (; a=objectid(a), b=objectid(b)) : nothing
208+
209+
if !isnothing(id_maps)
210+
has_a = haskey(id_maps.a, ids.a)
211+
has_b = haskey(id_maps.b, ids.b)
212+
if has_a && has_b
213+
return true
214+
elseif has_a has_b
215+
return false
216+
end
217+
end
220218

221-
result = if degree == 0
222-
leaf_equal(a, b)
223-
elseif degree == 1
224-
branch_equal(a, b) && inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b)
225-
else
226-
branch_equal(a, b) &&
227-
inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b) &&
228-
inner_is_equal_shared(a.r, b.r, id_map_a, id_map_b)
219+
deg = a.degree
220+
result = if deg != b.degree
221+
false
222+
elseif deg == 0
223+
leaf_equal(a, b)
224+
else
225+
(
226+
branch_equal(a, b) && Base.Cartesian.@nif(
227+
$D,
228+
i -> deg == i,
229+
i -> let cs_a = children(a, Val(i)), cs_b = children(b, Val(i))
230+
Base.Cartesian.@nall(i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps))
231+
end
232+
)
233+
)
234+
end
235+
if !isnothing(ids)
236+
id_maps.a[ids.a] = nothing
237+
id_maps.b[ids.b] = nothing
238+
end
239+
return result
229240
end
230-
231-
id_map_a[id_a] = nothing
232-
id_map_b[id_b] = nothing
233-
234-
return result
235241
end
236242

237243
@inline function branch_equal(a::AbstractExpressionNode, b::AbstractExpressionNode)
@@ -240,7 +246,8 @@ end
240246
@inline function leaf_equal(
241247
a::AbstractExpressionNode{T1}, b::AbstractExpressionNode{T2}
242248
) where {T1,T2}
243-
(constant = a.constant) != b.constant && return false
249+
constant = a.constant
250+
constant != b.constant && return false
244251
if constant
245252
return a.val::T1 == b.val::T2
246253
else
@@ -521,11 +528,10 @@ using `convert(T1, tree.val)` at constant nodes.
521528
"""
522529
function convert(
523530
::Type{N1}, tree::N2
524-
) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}}
531+
) where {T1,T2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2}}
525532
if N1 === N2
526533
return tree
527534
end
528-
@assert max_degree(N1) == max_degree(N2)
529535
return tree_mapreduce(
530536
Base.Fix1(leaf_convert, N1),
531537
identity,
@@ -535,11 +541,6 @@ function convert(
535541
)
536542
# TODO: Need to allow user to overload this!
537543
end
538-
function convert(
539-
::Type{N1}, tree::N2
540-
) where {T1,T2,D2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2,D2}}
541-
return convert(with_max_degree(N1, Val(D2)), tree)
542-
end
543544
function convert(
544545
::Type{N1}, tree::N2
545546
) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}}

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using TestItemRunner
44
# Check if SR_ENZYME_TEST is set in env
55
test_name = split(get(ENV, "SR_TEST", "main"), ",")
66

7-
unknown_tests = filter(Base.Fix2(, ["enzyme", "jet", "main"]), test_name)
7+
unknown_tests = filter(Base.Fix2(, ["enzyme", "jet", "main", "narity"]), test_name)
88

99
if !isempty(unknown_tests)
1010
error("Unknown test names: $unknown_tests")
@@ -49,3 +49,7 @@ if "main" in test_name
4949
include("unittest.jl")
5050
@run_package_tests
5151
end
52+
if "narity" in test_name
53+
include("test_n_arity_nodes.jl")
54+
@run_package_tests filter = ti -> (:narity in ti.tags)
55+
end

test/test_extra_node_fields.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D}
1111
frozen::Bool # Extra field!
1212
feature::UInt16
1313
op::UInt8
14-
children::NTuple{D,Base.RefValue{FrozenNode{T,D}}}
14+
children::NTuple{D,FrozenNode{T,D}}
1515

1616
function FrozenNode{_T,_D}() where {_T,_D}
1717
n = new{_T,_D}()

0 commit comments

Comments
 (0)