11module NodeModule
22
33using DispatchDoctor: @unstable
4- using StaticArrays: SizedVector
54
65import .. OperatorEnumModule: AbstractOperatorEnum
76import .. UtilsModule: @memoize_on , @with_memoize , deprecate_varmap, Undefined
87
98const DEFAULT_NODE_TYPE = Float32
109
1110"""
12- AbstractNode{D,shared }
11+ AbstractNode{D}
1312
14- Abstract type for D-arity trees. If `shared`, the node type
15- permits graph-like structures. Must have the following fields:
13+ Abstract type for D-arity trees. Must have the following fields:
1614
1715- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
1816 then `l` needs to be defined as the left child. If 2,
1917 then `r` also needs to be defined as the right child.
20- - `children`: A collection of D children nodes.
18+ - `children`: A collection of D references to children nodes.
2119
2220# Deprecated fields
2321
@@ -29,7 +27,7 @@ permits graph-like structures. Must have the following fields:
2927- `r::AbstractNode{D}`: Right child of the current node. Should only
3028 be defined if `degree == 2`.
3129"""
32- abstract type AbstractNode{D,shared } end
30+ abstract type AbstractNode{D} end
3331
3432"""
3533 AbstractExpressionNode{T,D} <: AbstractNode{D}
@@ -73,25 +71,27 @@ You likely do not need to, but you could choose to override the following:
7371- `with_type_parameters`
7472
7573"""
76- abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end
77-
78- mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared}
79- degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
80- constant:: Bool # false if variable
81- val:: T # If is a constant, this stores the actual value
82- feature:: UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
83- op:: UInt8 # If operator, this is the index of the operator in the degree-specific operator enum
84- children:: SizedVector{D,GeneralNode{T,D,shared}} # Children nodes
85-
86- # ################
87- # # Constructors:
88- # ################
89- GeneralNode {_T,_D,_shared} () where {_T,_D,_shared} = new {_T,_D,_shared} ()
74+ abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end
75+
76+ for N in (:Node , :GraphNode )
77+ @eval mutable struct $ N{T,D} <: AbstractExpressionNode{T,D}
78+ degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
79+ constant:: Bool # false if variable
80+ val:: T # If is a constant, this stores the actual value
81+ feature:: UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
82+ op:: UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
83+ children:: NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes
84+
85+ # ################
86+ # # Constructors:
87+ # ################
88+ $ N {_T,_D} () where {_T,_D} = new {_T,_D} ()
89+ end
9090end
9191
9292# ! format: off
9393"""
94- Node{T} <: AbstractExpressionNode{T,2 }
94+ Node{T,D } <: AbstractExpressionNode{T,D }
9595
9696Node defines a symbolic expression stored in a binary tree.
9797A single `Node` instance is one "node" of this tree, and
@@ -113,7 +113,7 @@ nodes, you can evaluate or print a given expression.
113113 operator in `operators.binops`. In other words, this is an enum
114114 of the operators, and is dependent on the specific `OperatorEnum`
115115 object. Only defined if `degree >= 1`
116- - `children::SizedArray {D,Node{T,D}}`: Children of the node. Only defined up to `degree`
116+ - `children::NTuple {D,Base.RefValue{ Node{T,D} }}`: Children of the node. Only defined up to `degree`
117117
118118# Constructors
119119
@@ -130,13 +130,13 @@ You may also construct nodes via the convenience operators generated by creating
130130You may also choose to specify a default memory allocator for the node other than simply `Node{T}()`
131131in the `allocator` keyword argument.
132132"""
133- const Node{T} = GeneralNode{T, 2 , false }
133+ Node
134134
135135
136136"""
137- GraphNode{T} <: AbstractExpressionNode{T}
137+ GraphNode{T,D } <: AbstractExpressionNode{T,D }
138138
139- Exactly the same as [`Node{T}`](@ref), but with the assumption that some
139+ Exactly the same as [`Node{T,D }`](@ref), but with the assumption that some
140140nodes will be shared. All copies of this graph-like structure will
141141be performed with this assumption, to preserve structure of the graph.
142142
@@ -164,7 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes
164164are created simply by using the same node in multiple places
165165when constructing or setting properties.
166166"""
167- const GraphNode{T} = GeneralNode{T, 2 , true }
167+ GraphNode
168168
169169# ###############################################################################
170170# ! format: on
@@ -177,30 +177,41 @@ function max_degree(::Type{N}) where {N<:AbstractExpressionNode}
177177end
178178
179179@unstable constructorof (:: Type{<:Node} ) = Node
180+ @unstable constructorof (:: Type{<:Node{T,D} where T} ) where {D} = Node{T,D} where T
180181@unstable constructorof (:: Type{<:GraphNode} ) = GraphNode
182+ @unstable constructorof (:: Type{<:GraphNode{T,D} where T} ) where {D} = GraphNode{T,D} where T
181183
182- with_type_parameters (:: Type{<:Node} , :: Type{T} ) where {T} = Node{T}
183- with_type_parameters (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode{T}
184+ with_type_parameters (:: Type{<:Node} , :: Type{T} , :: Val{D} = Val ( 2 )) where {T,D } = Node{T,D }
185+ with_type_parameters (:: Type{<:GraphNode} , :: Type{T} , :: Val{D} = Val ( 2 )) where {T,D } = GraphNode{T,D }
184186
185- default_allocator (:: Type{<:Node} , :: Type{T} ) where {T} = Node {T} ()
186- default_allocator (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode {T} ()
187+ default_allocator (:: Type{<:Node} , :: Type{T} , :: Val{D} = Val ( 2 )) where {T,D } = Node {T,D } ()
188+ default_allocator (:: Type{<:GraphNode} , :: Type{T} , :: Val{D} = Val ( 2 )) where {T,D } = GraphNode {T,D } ()
187189
188190""" Trait declaring whether nodes share children or not."""
189191preserve_sharing (:: Union{Type{<:AbstractNode},AbstractNode} ) = false
190- function preserve_sharing (
191- :: Union{Type{<:G},G}
192- ) where {shared,G<: GeneralNode{T,D,shared} where {T,D}}
193- return shared
194- end
192+ preserve_sharing (:: Union{Type{<:GraphNode},GraphNode} ) = true
195193
196194include (" base.jl" )
197195
198196# ! format: off
199197@inline function (:: Type{N} )(
200- :: Type{T1} = Undefined; val = nothing , feature = nothing , op = nothing , children = nothing , allocator :: F = default_allocator,
198+ :: Type{T1} = Undefined; kws ...
201199) where {T1,N<: AbstractExpressionNode ,F}
202- validate_not_all_defaults (N, val, feature, op, children)
203- return node_factory (N, T1, val, feature, op, children, allocator)
200+ end
201+ @inline function (:: Type{N} )(
202+ :: Type{T1} = Undefined; val= nothing , feature= nothing , op= nothing , l= nothing , r= nothing , children= nothing , allocator:: F = default_allocator,
203+ ) where {T1,D,N<: AbstractExpressionNode{T,D} where T,F}
204+ _children = if l != = nothing && r === nothing
205+ @assert children === nothing
206+ (l,)
207+ elseif l != = nothing && r != = nothing
208+ @assert children === nothing
209+ (l, r)
210+ else
211+ children
212+ end
213+ validate_not_all_defaults (N, val, feature, op, _children)
214+ return node_factory (N, T1, val, feature, op, _children, allocator)
204215end
205216function validate_not_all_defaults (:: Type{N} , val, feature, op, children) where {N<: AbstractExpressionNode }
206217 return nothing
@@ -209,7 +220,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, children) where
209220 if val === nothing && feature === nothing && op === nothing && children === nothing
210221 error (
211222 " Encountered the call for $N () inside the generic constructor. "
212- * " Did you forget to define `$(Base. typename (N). wrapper) {T}() where {T} = new{T}()`?"
223+ * " Did you forget to define `$(Base. typename (N). wrapper) {T,D }() where {T,D } = new{T,D }()`?"
213224 )
214225 end
215226 return nothing
219230 :: Type{N} , :: Type{T1} , val:: T2 , :: Nothing , :: Nothing , :: Nothing , allocator:: F ,
220231) where {N,T1,T2,F}
221232 T = node_factory_type (N, T1, T2)
222- n = allocator (N, T)
233+ n = allocator (N, T, D )
223234 n. degree = 0
224235 n. constant = true
225236 n. val = convert (T, val)
230241 :: Type{N} , :: Type{T1} , :: Nothing , feature:: Integer , :: Nothing , :: Nothing , allocator:: F ,
231242) where {N,T1,F}
232243 T = node_factory_type (N, T1, DEFAULT_NODE_TYPE)
233- n = allocator (N, T)
244+ n = allocator (N, T, D )
234245 n. degree = 0
235246 n. constant = false
236247 n. feature = feature
@@ -239,19 +250,14 @@ end
239250""" Create an operator node."""
240251@inline function node_factory (
241252 :: Type{N} , :: Type , :: Nothing , :: Nothing , op:: Integer , children:: NTuple{D2} , allocator:: F ,
242- ) where {N,F,D2}
243- D = max_degree (N)
244- @assert D2 <= D
253+ ) where {D,N<: AbstractExpressionNode{T where T,D} ,F,D2}
245254 T = promote_type (map (eltype, children)... ) # Always prefer existing nodes, so we don't mess up references from conversion
246- NT = with_type_parameters (N, T)
247- n = allocator (N, T)
255+ NT = with_type_parameters (N, T, D )
256+ n = allocator (N, T, D )
248257 n. degree = D2
249258 n. op = op
250- ar = SizedVector {D,NT} (undef)
251- for i in 1 : D2
252- ar[i] = children[i]
253- end
254- n. children = ar
259+ n. children
260+ # map(Ref, children)
255261 return n
256262end
257263
0 commit comments