@@ -174,13 +174,19 @@ end
174174Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
175175By using this instead of tree_mapreduce, we can take advantage of early exits.
176176"""
177- function any (f:: F , tree:: AbstractNode ) where {F<: Function }
178- if tree. degree == 0
179- return @inline (f (tree)):: Bool
180- elseif tree. degree == 1
181- return @inline (f (tree)):: Bool || any (f, tree. l)
182- else
183- return @inline (f (tree)):: Bool || any (f, tree. l) || any (f, tree. r)
177+ @generated function any (f:: F , tree:: AbstractNode{D} ) where {F<: Function ,D}
178+ quote
179+ deg = tree. degree
180+
181+ deg == 0 && return @inline (f (tree))
182+
183+ return (
184+ @inline (f (tree)) || Base. Cartesian. @nif (
185+ $ D, i -> deg == i, i -> let cs = children (tree, Val (i))
186+ Base. Cartesian. @nany (i, j -> any (f, cs[j]))
187+ end
188+ )
189+ )
184190 end
185191end
186192
@@ -189,49 +195,49 @@ function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode)
189195end
190196function Base.:(== )(a:: N , b:: N ):: Bool where {N<: AbstractExpressionNode }
191197 if preserve_sharing (N)
192- return inner_is_equal_shared (a, b, Dict {UInt,Nothing} (), Dict {UInt,Nothing} ())
198+ return inner_is_equal (a, b, (; a = Dict {UInt,Nothing} (), b = Dict {UInt,Nothing} () ))
193199 else
194- return inner_is_equal (a, b)
200+ return inner_is_equal (a, b, nothing )
195201 end
196202end
197- function inner_is_equal (a, b)
198- (degree = a. degree) != b. degree && return false
199- if degree == 0
200- return leaf_equal (a, b)
201- elseif degree == 1
202- return branch_equal (a, b) && inner_is_equal (a. l, b. l)
203- else
204- return branch_equal (a, b) && inner_is_equal (a. l, b. l) && inner_is_equal (a. r, b. r)
205- end
206- end
207- function inner_is_equal_shared (a, b, id_map_a, id_map_b)
208- id_a = objectid (a)
209- id_b = objectid (b)
210- has_a = haskey (id_map_a, id_a)
211- has_b = haskey (id_map_b, id_b)
212-
213- if has_a && has_b
214- return true
215- elseif has_a ⊻ has_b
216- return false
217- end
218-
219- (degree = a. degree) != b. degree && return false
203+ @generated function inner_is_equal (
204+ a:: AbstractNode{D} , b:: AbstractNode{D} , id_maps:: Union{Nothing,NamedTuple}
205+ ) where {D}
206+ quote
207+ ids = ! isnothing (id_maps) ? (; a= objectid (a), b= objectid (b)) : nothing
208+
209+ if ! isnothing (id_maps)
210+ has_a = haskey (id_maps. a, ids. a)
211+ has_b = haskey (id_maps. b, ids. b)
212+ if has_a && has_b
213+ return true
214+ elseif has_a ⊻ has_b
215+ return false
216+ end
217+ end
220218
221- result = if degree == 0
222- leaf_equal (a, b)
223- elseif degree == 1
224- branch_equal (a, b) && inner_is_equal_shared (a. l, b. l, id_map_a, id_map_b)
225- else
226- branch_equal (a, b) &&
227- inner_is_equal_shared (a. l, b. l, id_map_a, id_map_b) &&
228- inner_is_equal_shared (a. r, b. r, id_map_a, id_map_b)
219+ deg = a. degree
220+ result = if deg != b. degree
221+ false
222+ elseif deg == 0
223+ leaf_equal (a, b)
224+ else
225+ (
226+ branch_equal (a, b) && Base. Cartesian. @nif (
227+ $ D,
228+ i -> deg == i,
229+ i -> let cs_a = children (a, Val (i)), cs_b = children (b, Val (i))
230+ Base. Cartesian. @nall (i, j -> inner_is_equal (cs_a[j], cs_b[j], id_maps))
231+ end
232+ )
233+ )
234+ end
235+ if ! isnothing (ids)
236+ id_maps. a[ids. a] = nothing
237+ id_maps. b[ids. b] = nothing
238+ end
239+ return result
229240 end
230-
231- id_map_a[id_a] = nothing
232- id_map_b[id_b] = nothing
233-
234- return result
235241end
236242
237243@inline function branch_equal (a:: AbstractExpressionNode , b:: AbstractExpressionNode )
240246@inline function leaf_equal (
241247 a:: AbstractExpressionNode{T1} , b:: AbstractExpressionNode{T2}
242248) where {T1,T2}
243- (constant = a. constant) != b. constant && return false
249+ constant = a. constant
250+ constant != b. constant && return false
244251 if constant
245252 return a. val:: T1 == b. val:: T2
246253 else
@@ -521,11 +528,10 @@ using `convert(T1, tree.val)` at constant nodes.
521528"""
522529function convert (
523530 :: Type{N1} , tree:: N2
524- ) where {T1,T2,D1,D2, N1<: AbstractExpressionNode{T1,D1 } ,N2<: AbstractExpressionNode{T2,D2 } }
531+ ) where {T1,T2,N1<: AbstractExpressionNode{T1} ,N2<: AbstractExpressionNode{T2} }
525532 if N1 === N2
526533 return tree
527534 end
528- @assert max_degree (N1) == max_degree (N2)
529535 return tree_mapreduce (
530536 Base. Fix1 (leaf_convert, N1),
531537 identity,
@@ -535,11 +541,6 @@ function convert(
535541 )
536542 # TODO : Need to allow user to overload this!
537543end
538- function convert (
539- :: Type{N1} , tree:: N2
540- ) where {T1,T2,D2,N1<: AbstractExpressionNode{T1} ,N2<: AbstractExpressionNode{T2,D2} }
541- return convert (with_max_degree (N1, Val (D2)), tree)
542- end
543544function convert (
544545 :: Type{N1} , tree:: N2
545546) where {T2,N1<: AbstractExpressionNode ,N2<: AbstractExpressionNode{T2} }
0 commit comments