Skip to content

Commit 008dfbc

Browse files
committed
feat: better interface for children
1 parent b78097a commit 008dfbc

File tree

8 files changed

+70
-54
lines changed

8 files changed

+70
-54
lines changed

src/Evaluate.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EvaluateModule
33
using DispatchDoctor: @stable, @unstable
44

55
import ..NodeModule:
6-
AbstractExpressionNode, constructorof, max_degree, children, with_type_parameters
6+
AbstractExpressionNode, constructorof, max_degree, get_children, with_type_parameters
77
import ..StringsModule: string_tree
88
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
99
import ..UtilsModule: fill_similar, counttuple, ResultOk
@@ -343,7 +343,7 @@ end
343343
) where {T,degree,OPS}
344344
nops = length(OPS.types[degree].types)
345345
return quote
346-
cs = children(tree, Val($degree))
346+
cs = get_children(tree, Val($degree))
347347
Base.Cartesian.@nexprs(
348348
$degree,
349349
i -> begin
@@ -727,7 +727,7 @@ end
727727
) where {T,degree,OPS}
728728
nops = length(OPS.types[degree].types)
729729
get_inputs = quote
730-
cs = children(tree, Val($degree))
730+
cs = get_children(tree, Val($degree))
731731
Base.Cartesian.@nexprs(
732732
$degree,
733733
i -> begin

src/Interfaces.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using ..NodeModule:
1212
constructorof,
1313
default_allocator,
1414
with_type_parameters,
15-
children,
15+
get_children,
1616
leaf_copy,
1717
leaf_convert,
1818
leaf_hash,
@@ -226,12 +226,12 @@ function _check_create_node(tree::AbstractExpressionNode)
226226
NT = with_type_parameters(N, Float16)
227227
return NT() isa NT
228228
end
229-
function _check_children(tree::AbstractExpressionNode{T,D}) where {T,D}
229+
function _check_get_children(tree::AbstractExpressionNode{T,D}) where {T,D}
230230
tree.degree == 0 && return true
231-
return children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} &&
232-
children(tree, Val(D)) isa Tuple &&
233-
length(children(tree, Val(D))) == D &&
234-
length(children(tree, Val(1))) == 1
231+
return get_children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} &&
232+
get_children(tree, Val(D)) isa Tuple &&
233+
length(get_children(tree, Val(D))) == D &&
234+
length(get_children(tree, Val(1))) == 1
235235
end
236236
function _check_copy(tree::AbstractExpressionNode)
237237
return copy(tree) isa typeof(tree)
@@ -308,19 +308,19 @@ function _check_leaf_equal(tree::AbstractExpressionNode)
308308
end
309309
function _check_branch_copy(tree::AbstractExpressionNode)
310310
tree.degree == 0 && return true
311-
return branch_copy(tree, children(tree, Val(tree.degree))...) isa typeof(tree)
311+
return branch_copy(tree, get_children(tree, Val(tree.degree))...) isa typeof(tree)
312312
end
313313
function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T}
314314
tree.degree == 0 && return true
315315
new_branch = constructorof(typeof(tree))(; val=zero(T))
316316
ret = branch_copy_into!(
317-
new_branch, tree, map(copy, children(tree, Val(tree.degree)))...
317+
new_branch, tree, map(copy, get_children(tree, Val(tree.degree)))...
318318
)
319319
return new_branch == tree && ret === new_branch
320320
end
321321
function _check_branch_convert(tree::AbstractExpressionNode)
322322
tree.degree == 0 && return true
323-
return branch_convert(typeof(tree), tree, children(tree, Val(tree.degree))...) isa
323+
return branch_convert(typeof(tree), tree, get_children(tree, Val(tree.degree))...) isa
324324
typeof(tree)
325325
end
326326
function _check_branch_hash(tree::AbstractExpressionNode)
@@ -367,7 +367,7 @@ end
367367
ni_components = (
368368
mandatory = (
369369
create_node = "creates a new instance of the node type" => _check_create_node,
370-
children = "returns the children of the node" => _check_children,
370+
get_children = "returns the children of the node" => _check_get_children,
371371
copy = "returns a copy of the tree" => _check_copy,
372372
hash = "returns the hash of the tree" => _check_hash,
373373
any = "checks if any element of the tree satisfies a condition" => _check_any,

src/Node.jl

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -176,34 +176,56 @@ function get_poison(n::AbstractNode)
176176
return n
177177
end
178178

179+
@inline function get_children(node::AbstractNode)
180+
return getfield(node, :children)
181+
end
182+
@inline function get_children(node::AbstractNode, ::Val{n}) where {n}
183+
cs = get_children(node)
184+
return ntuple(i -> cs[i], Val(Int(n)))
185+
end
186+
@inline function get_child(n::AbstractNode{D}, i::Int) where {D}
187+
return get_children(n)[i]
188+
end
189+
@inline function set_child!(n::AbstractNode{D}, child::AbstractNode{D}, i::Int) where {D}
190+
set_children!(n, Base.setindex(get_children(n), child, i))
191+
return child
192+
end
193+
@inline function set_children!(n::AbstractNode{D}, children::NTuple{D2,AbstractNode{D}}) where {D,D2}
194+
if D === D2
195+
n.children = children
196+
else
197+
poison = get_poison(n)
198+
# We insert poison at the end of the tuple so that
199+
# errors will appear loudly if accessed.
200+
# This poison should be efficient to insert. So
201+
# for simplicity, we can just use poison == n, which
202+
# will trigger infinite recursion errors if accessed.
203+
n.children = ntuple(i -> i <= D2 ? children[i] : poison, Val(D))
204+
end
205+
end
206+
179207
macro make_accessors(node_type)
180208
esc(quote
181209
@inline function Base.getproperty(n::$node_type, k::Symbol)
182210
if k == :l
183211
# TODO: Should a depwarn be raised here? Or too slow?
184-
return getfield(n, :children)[1]
212+
return $(get_child)(n, 1)
185213
elseif k == :r
186-
return getfield(n, :children)[2]
214+
return $(get_child)(n, 2)
187215
else
188216
return getfield(n, k)
189217
end
190218
end
191219
@inline function Base.setproperty!(n::$node_type, k::Symbol, v)
192220
if k == :l
193221
if isdefined(n, :children)
194-
old = getfield(n, :children)
195-
setfield!(n, :children, (v, old[2]))
196-
v
222+
$(set_child!)(n, v, 1)
197223
else
198-
poison = $(get_poison)(n)
199-
setfield!(n, :children, (v, poison))
224+
$(set_children!)(n, (v,))
200225
v
201226
end
202227
elseif k == :r
203-
# TODO: Remove this assert once we know that this is safe
204-
old = getfield(n, :children)
205-
setfield!(n, :children, (old[1], v))
206-
v
228+
$(set_child!)(n, v, 2)
207229
else
208230
T = fieldtype(typeof(n), k)
209231
if v isa T
@@ -222,12 +244,6 @@ end
222244
@make_accessors GraphNode
223245
# TODO: Disable the `.l` accessors eventually, once the codebase is fully generic
224246

225-
@inline children(node::AbstractNode) = node.children
226-
@inline function children(node::AbstractNode, ::Val{n}) where {n}
227-
cs = children(node)
228-
return ntuple(i -> cs[i], Val(Int(n)))
229-
end
230-
231247
################################################################################
232248
#! format: on
233249

@@ -273,11 +289,11 @@ include("base.jl")
273289
@inline function (::Type{N})(
274290
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator,
275291
) where {T1,N<:AbstractExpressionNode{T} where T,F}
276-
_children = if l !== nothing && r === nothing
277-
@assert children === nothing
292+
_children = if !isnothing(l) && isnothing(r)
293+
@assert isnothing(children)
278294
(l,)
279-
elseif l !== nothing && r !== nothing
280-
@assert children === nothing
295+
elseif !isnothing(l) && !isnothing(r)
296+
@assert isnothing(children)
281297
(l, r)
282298
else
283299
children
@@ -328,8 +344,7 @@ end
328344
n = allocator(N, T)
329345
n.degree = D2
330346
n.op = op
331-
poison = get_poison(n)
332-
n.children = ntuple(i -> i <= D2 ? convert(NT, children[i]) : poison, Val(max_degree(N)))
347+
set_children!(n, children)
333348
return n
334349
end
335350

@@ -398,7 +413,7 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod
398413
end
399414
else
400415
tree.op = new_tree.op
401-
tree.children = new_tree.children
416+
set_children!(tree, get_children(new_tree))
402417
end
403418
return nothing
404419
end

src/NodePreallocation.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using ..NodeModule:
77
leaf_copy,
88
branch_copy,
99
set_node!,
10-
get_poison
10+
set_children!
1111

1212
"""
1313
allocate_container(prototype::AbstractExpressionNode, n=nothing)
@@ -60,8 +60,7 @@ function branch_copy_into!(
6060
) where {T,D,N<:AbstractExpressionNode{T,D},M}
6161
dest.degree = M
6262
dest.op = src.op
63-
poison = get_poison(dest)
64-
dest.children = ntuple(i -> i <= M ? children[i] : poison, D)
63+
set_children!(dest, children)
6564
return dest
6665
end
6766

src/NodeUtils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ..NodeModule:
66
Node,
77
preserve_sharing,
88
constructorof,
9-
get_poison,
9+
set_children!,
1010
copy_node,
1111
count_nodes,
1212
tree_mapreduce,
@@ -156,10 +156,9 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D}
156156
::Type{_T}, ::Val{_D}, child::NodeIndex{_T,_D}, childs::Vararg{NodeIndex{_T,_D},_D2}
157157
) where {_T,_D,_D2}
158158
node = NodeIndex(_T, Val(_D))
159-
poison = get_poison(node)
160159
children = (child, childs...)
161160
node.degree = _D2 + 1
162-
node.children = ntuple(i -> i <= _D2 + 1 ? children[i] : poison, Val(_D))
161+
set_children!(node, children)
163162
return node
164163
end
165164
end

src/ReadOnlyNode.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ReadOnlyNodeModule
33
using DispatchDoctor: @unstable
44

55
using ..NodeModule: AbstractExpressionNode, Node
6-
import ..NodeModule: default_allocator, with_type_parameters, constructorof, children
6+
import ..NodeModule: default_allocator, with_type_parameters, constructorof, get_children
77

88
abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D},IS_REF} <:
99
AbstractExpressionNode{T,D} end
@@ -38,8 +38,8 @@ Base.getindex(n::AbstractReadOnlyNode{T,D,N,true} where {T,D,N}) = n
3838
return out
3939
end
4040
end
41-
@inline function children(node::AbstractReadOnlyNode, ::Val{n}) where {n}
42-
return map(ReadOnlyNode, children(inner(node), Val(n)))
41+
@inline function get_children(node::AbstractReadOnlyNode)
42+
return map(ReadOnlyNode, get_children(inner(node)))
4343
end
4444
function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v)
4545
return error("Cannot set properties on a ReadOnlyNode")

src/base.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
Base.Cartesian.@nif(
138138
$D,
139139
i -> i == d,
140-
i -> let cs = children(tree, Val(i))
140+
i -> let cs = get_children(tree, Val(i))
141141
Base.Cartesian.@ncall(
142142
i,
143143
mapreducer.op,
@@ -182,7 +182,7 @@ By using this instead of tree_mapreduce, we can take advantage of early exits.
182182

183183
return (
184184
@inline(f(tree)) || Base.Cartesian.@nif(
185-
$D, i -> deg == i, i -> let cs = children(tree, Val(i))
185+
$D, i -> deg == i, i -> let cs = get_children(tree, Val(i))
186186
Base.Cartesian.@nany(i, j -> any(f, cs[j]))
187187
end
188188
)
@@ -226,9 +226,12 @@ end
226226
branch_equal(a, b) && Base.Cartesian.@nif(
227227
$D,
228228
i -> deg == i,
229-
i -> let cs_a = children(a, Val(i)), cs_b = children(b, Val(i))
230-
Base.Cartesian.@nall(i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps))
231-
end
229+
i ->
230+
let cs_a = get_children(a, Val(i)), cs_b = get_children(b, Val(i))
231+
Base.Cartesian.@nall(
232+
i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps)
233+
)
234+
end
232235
)
233236
)
234237
end

test/test_n_arity_nodes.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
@test n_bin.children[1] === n_bin_leaf1
3333
@test n_bin.children[2] === n_bin_leaf2
3434
@test n_bin.children[3] === n_bin # Poison
35-
@test DynamicExpressions.NodeModule.children(n_bin, Val(2)) ==
35+
@test DynamicExpressions.NodeModule.get_children(n_bin, Val(2)) ==
3636
(n_bin_leaf1, n_bin_leaf2)
3737
# .l and .r should work for Node{T,3} due to general @make_accessors Node
3838
@test n_bin.l === n_bin_leaf1
@@ -49,7 +49,7 @@
4949
@test n_ter.children[1] === n_ter_leaf1
5050
@test n_ter.children[2] === n_ter_leaf2
5151
@test n_ter.children[3] === n_ter_leaf3
52-
@test DynamicExpressions.NodeModule.children(n_ter, Val(3)) ==
52+
@test DynamicExpressions.NodeModule.get_children(n_ter, Val(3)) ==
5353
(n_ter_leaf1, n_ter_leaf2, n_ter_leaf3)
5454
@test n_ter.l === n_ter_leaf1
5555
@test n_ter.r === n_ter_leaf2
@@ -461,7 +461,7 @@ end
461461
@test readonly_tree.degree == 3
462462
@test readonly_tree.op == 1
463463

464-
ro_children = DynamicExpressions.NodeModule.children(readonly_tree, Val(3))
464+
ro_children = DynamicExpressions.NodeModule.get_children(readonly_tree, Val(3))
465465
@test length(ro_children) == 3
466466
@test ro_children[1] isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode
467467
@test ro_children[1].feature == 1

0 commit comments

Comments
 (0)