@@ -2,7 +2,7 @@ module NodeModule
22
33using DispatchDoctor: @unstable
44
5- import .. UtilsModule: deprecate_varmap, Undefined
5+ import .. UtilsModule: deprecate_varmap, Undefined, Nullable
66
77const DEFAULT_NODE_TYPE = Float32
88const DEFAULT_MAX_DEGREE = 2
@@ -78,7 +78,7 @@ for N in (:Node, :GraphNode)
7878 val:: T # If is a constant, this stores the actual value
7979 feature:: UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
8080 op:: UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
81- children:: NTuple{D,$N{T,D}}
81+ children:: NTuple{D,Nullable{ $N{T,D} }}
8282
8383 # ################
8484 # # Constructors:
@@ -173,13 +173,7 @@ Accessing this node should trigger some kind of noticable error
173173(e.g., default returns itself, which causes infinite recursion).
174174"""
175175function poison_node (n:: AbstractNode )
176- # We don't want to use `nothing` because the type instability
177- # hits memory hard.
178- # Setting itself as the right child is the best thing,
179- # because it (1) doesn't allocate, and (2) will trigger
180- # infinite recursion errors if someone is mistakenly trying
181- # to access the right child when `.degree == 1`.
182- return n
176+ return Nullable (true , n)
183177end
184178
185179"""
@@ -190,7 +184,7 @@ children may be "poisoned" nodes which you should not access,
190184as they will trigger infinite recursion errors. Ensure to
191185only access children only up to the `.degree` of the node.
192186"""
193- @inline function get_children (node:: AbstractNode )
187+ @inline function unsafe_get_children (node:: AbstractNode )
194188 return getfield (node, :children )
195189end
196190
@@ -207,8 +201,8 @@ for total type stability.
207201 return get_children (node, Val (n))
208202end
209203@inline function get_children (node:: AbstractNode{D} , :: Val{n} ) where {D,n}
210- cs = get_children (node)
211- return ntuple (i -> cs[i], Val (n))
204+ cs = unsafe_get_children (node)
205+ return ntuple (i -> cs[i][] , Val (n))
212206end
213207
214208"""
217211Return the `i`-th child of a node (1-indexed).
218212"""
219213@inline function get_child (n:: AbstractNode{D} , i:: Int ) where {D}
220- return get_children (n)[i]
214+ return unsafe_get_children (n)[i][ ]
221215end
222216
223217"""
@@ -227,7 +221,7 @@ Replace the `i`-th child of a node (1-indexed) with the given child node.
227221Returns the new child. Updates the children tuple in-place.
228222"""
229223@inline function set_child! (n:: AbstractNode{D} , child:: AbstractNode{D} , i:: Int ) where {D}
230- set_children! (n, Base. setindex (get_children (n), child, i))
224+ set_children! (n, Base. setindex (unsafe_get_children (n), Nullable ( false , child) , i))
231225 return child
232226end
233227
@@ -242,17 +236,21 @@ provided than the node's maximum degree, remaining slots are filled with poison
242236) where {D}
243237 D2 = length (children)
244238 if D === D2
245- n. children = children
239+ n. children = ntuple (i -> _ensure_nullable ( @inbounds ( children[i])), Val (D))
246240 else
247241 poison = poison_node (n)
248242 # We insert poison at the end of the tuple so that
249243 # errors will appear loudly if accessed.
250244 # This poison should be efficient to insert. So
251- # for simplicity, we can just use poison == n, which
252- # will trigger infinite recursion errors if accessed.
253- n. children = ntuple (i -> i <= D2 ? @inbounds (children[i]) : poison, Val (D))
245+ # for simplicity, we can just use poison := Nullable(true, n)
246+ # which will raise an UndefRefError if accessed.
247+ n. children = ntuple (
248+ i -> i <= D2 ? _ensure_nullable (@inbounds (children[i])) : poison, Val (D)
249+ )
254250 end
255251end
252+ @inline _ensure_nullable (x) = Nullable (false , x)
253+ @inline _ensure_nullable (x:: Nullable ) = x
256254
257255macro make_accessors (node_type)
258256 esc (
@@ -491,7 +489,7 @@ function set_leaf!(tree::AbstractExpressionNode, new_leaf::AbstractExpressionNod
491489end
492490function set_branch! (tree:: AbstractExpressionNode , new_branch:: AbstractExpressionNode )
493491 tree. op = new_branch. op
494- set_children! (tree, get_children (new_branch))
492+ set_children! (tree, unsafe_get_children (new_branch))
495493 return nothing
496494end
497495
0 commit comments