@@ -3,34 +3,30 @@ module NodeModule
33using DispatchDoctor: @unstable
44
55import .. OperatorEnumModule: AbstractOperatorEnum
6- import .. UtilsModule: deprecate_varmap, Undefined
6+ import .. UtilsModule: @memoize_on , @with_memoize , deprecate_varmap, Undefined
77
88const DEFAULT_NODE_TYPE = Float32
99
1010"""
11- AbstractNode{D}
11+ AbstractNode
1212
13- Abstract type for D-arity trees. Must have the following fields:
13+ Abstract type for binary trees. Must have the following fields:
1414
1515- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
1616 then `l` needs to be defined as the left child. If 2,
1717 then `r` also needs to be defined as the right child.
18- - `children`: A collection of D references to children nodes.
19-
20- # Deprecated fields
21-
22- - `l::AbstractNode{D}`: Left child of the current node. Should only be
18+ - `l::AbstractNode`: Left child of the current node. Should only be
2319 defined if `degree >= 1`; otherwise, leave it undefined (see the
2420 the constructors of [`Node{T}`](@ref) for an example).
2521 Don't use `nothing` to represent an undefined value
2622 as it will incur a large performance penalty.
27- - `r::AbstractNode{D} `: Right child of the current node. Should only
23+ - `r::AbstractNode`: Right child of the current node. Should only
2824 be defined if `degree == 2`.
2925"""
30- abstract type AbstractNode{D} end
26+ abstract type AbstractNode end
3127
3228"""
33- AbstractExpressionNode{T,D } <: AbstractNode{D}
29+ AbstractExpressionNode{T} <: AbstractNode
3430
3531Abstract type for nodes that represent an expression.
3632Along with the fields required for `AbstractNode`,
@@ -71,27 +67,11 @@ You likely do not need to, but you could choose to override the following:
7167- `with_type_parameters`
7268
7369"""
74- abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end
75-
76- for N in (:Node , :GraphNode )
77- @eval mutable struct $ N{T,D} <: AbstractExpressionNode{T,D}
78- degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
79- constant:: Bool # false if variable
80- val:: T # If is a constant, this stores the actual value
81- feature:: UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
82- op:: UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
83- children:: NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes
84-
85- # ################
86- # # Constructors:
87- # ################
88- $ N {_T,_D} () where {_T,_D} = new {_T,_D::Int} ()
89- end
90- end
70+ abstract type AbstractExpressionNode{T} <: AbstractNode end
9171
9272# ! format: off
9373"""
94- Node{T,D } <: AbstractExpressionNode{T,D }
74+ Node{T} <: AbstractExpressionNode{T}
9575
9676Node defines a symbolic expression stored in a binary tree.
9777A single `Node` instance is one "node" of this tree, and
@@ -101,42 +81,63 @@ nodes, you can evaluate or print a given expression.
10181# Fields
10282
10383- `degree::UInt8`: Degree of the node. 0 for constants, 1 for
104- unary operators, 2 for binary operators, etc. Maximum of `D` .
84+ unary operators, 2 for binary operators.
10585- `constant::Bool`: Whether the node is a constant.
10686- `val::T`: Value of the node. If `degree==0`, and `constant==true`,
10787 this is the value of the constant. It has a type specified by the
10888 overall type of the `Node` (e.g., `Float64`).
10989- `feature::UInt16`: Index of the feature to use in the
110- case of a feature node. Only defined if `degree == 0 && constant == false`.
90+ case of a feature node. Only used if `degree==0` and `constant==false`.
91+ Only defined if `degree == 0 && constant == false`.
11192- `op::UInt8`: If `degree==1`, this is the index of the operator
11293 in `operators.unaops`. If `degree==2`, this is the index of the
11394 operator in `operators.binops`. In other words, this is an enum
11495 of the operators, and is dependent on the specific `OperatorEnum`
11596 object. Only defined if `degree >= 1`
116- - `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree`
97+ - `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`.
98+ Same type as the parent node.
99+ - `r::Node{T}`: Right child of the node. Only defined if `degree == 2`.
100+ Same type as the parent node. This is to be passed as the right
101+ argument to the binary operator.
117102
118103# Constructors
119104
120105
121- Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
122- Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
106+ Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
107+ Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
123108
124109Create a new node in an expression tree. If `T` is not specified in either the type or the
125- first argument, it will be inferred from the value of `val` passed or the children.
126- The `children` keyword is used to pass in a collection of children nodes.
110+ first argument, it will be inferred from the value of `val` passed or `l` and/or `r`.
111+ If it cannot be inferred from these, it will default to `Float32`.
112+
113+ The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This
114+ is to permit the use of splatting in constructors.
127115
128116You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`.
129117
130118You may also choose to specify a default memory allocator for the node other than simply `Node{T}()`
131119in the `allocator` keyword argument.
132120"""
133- Node
134-
121+ mutable struct Node{T} <: AbstractExpressionNode{T}
122+ degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
123+ constant:: Bool # false if variable
124+ val:: T # If is a constant, this stores the actual value
125+ # ------------------- (possibly undefined below)
126+ feature:: UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
127+ op:: UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
128+ l:: Node{T} # Left child node. Only defined for degree=1 or degree=2.
129+ r:: Node{T} # Right child node. Only defined for degree=2.
130+
131+ # ################
132+ # # Constructors:
133+ # ################
134+ Node {_T} () where {_T} = new {_T} ()
135+ end
135136
136137"""
137- GraphNode{T,D } <: AbstractExpressionNode{T,D }
138+ GraphNode{T} <: AbstractExpressionNode{T}
138139
139- Exactly the same as [`Node{T,D }`](@ref), but with the assumption that some
140+ Exactly the same as [`Node{T}`](@ref), but with the assumption that some
140141nodes will be shared. All copies of this graph-like structure will
141142be performed with this assumption, to preserve structure of the graph.
142143
@@ -145,7 +146,7 @@ be performed with this assumption, to preserve structure of the graph.
145146```julia
146147julia> operators = OperatorEnum(;
147148 binary_operators=[+, -, *], unary_operators=[cos, sin]
148- );
149+ );
149150
150151julia> x = GraphNode(feature=1)
151152x1
@@ -164,38 +165,17 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes
164165are created simply by using the same node in multiple places
165166when constructing or setting properties.
166167"""
167- GraphNode
168-
169- @inline function Base. getproperty (n:: Union{Node,GraphNode} , k:: Symbol )
170- if k == :l
171- # TODO : Should a depwarn be raised here? Or too slow?
172- return getfield (n, :children )[1 ][]
173- elseif k == :r
174- return getfield (n, :children )[2 ][]
175- else
176- return getfield (n, k)
177- end
178- end
179- @inline function Base. setproperty! (n:: Union{Node,GraphNode} , k:: Symbol , v)
180- if k == :l
181- getfield (n, :children )[1 ][] = v
182- elseif k == :r
183- getfield (n, :children )[2 ][] = v
184- elseif k == :degree
185- setfield! (n, :degree , convert (UInt8, v))
186- elseif k == :constant
187- setfield! (n, :constant , convert (Bool, v))
188- elseif k == :feature
189- setfield! (n, :feature , convert (UInt16, v))
190- elseif k == :op
191- setfield! (n, :op , convert (UInt8, v))
192- elseif k == :val
193- setfield! (n, :val , convert (eltype (n), v))
194- elseif k == :children
195- setfield! (n, :children , v)
196- else
197- error (" Invalid property: $k " )
198- end
168+ mutable struct GraphNode{T} <: AbstractExpressionNode{T}
169+ degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
170+ constant:: Bool # false if variable
171+ val:: T # If is a constant, this stores the actual value
172+ # ------------------- (possibly undefined below)
173+ feature:: UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
174+ op:: UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
175+ l:: GraphNode{T} # Left child node. Only defined for degree=1 or degree=2.
176+ r:: GraphNode{T} # Right child node. Only defined for degree=2.
177+
178+ GraphNode {_T} () where {_T} = new {_T} ()
199179end
200180
201181# ###############################################################################
@@ -204,62 +184,59 @@ end
204184Base. eltype (:: Type{<:AbstractExpressionNode{T}} ) where {T} = T
205185Base. eltype (:: AbstractExpressionNode{T} ) where {T} = T
206186
207- max_degree (:: Type{<:AbstractNode} ) = 2 # Default
208- max_degree (:: Type{<:AbstractNode{D}} ) where {D} = D
209-
210- @unstable constructorof (:: Type{N} ) where {N<: Node } = Node{T,max_degree (N)} where {T}
211- @unstable constructorof (:: Type{N} ) where {N<: GraphNode } =
212- GraphNode{T,max_degree (N)} where {T}
187+ @unstable constructorof (:: Type{N} ) where {N<: AbstractNode } = Base. typename (N). wrapper
188+ @unstable constructorof (:: Type{<:Node} ) = Node
189+ @unstable constructorof (:: Type{<:GraphNode} ) = GraphNode
213190
214- with_type_parameters (:: Type{N} , :: Type{T} ) where {N<: Node ,T} = Node{T,max_degree (N)}
215- function with_type_parameters (:: Type{N} , :: Type{T} ) where {N<: GraphNode ,T}
216- return GraphNode{T,max_degree (N)}
191+ function with_type_parameters (:: Type{N} , :: Type{T} ) where {N<: AbstractExpressionNode ,T}
192+ return constructorof (N){T}
217193end
218-
219- # with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
220- # with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
194+ with_type_parameters (:: Type{<:Node} , :: Type{T} ) where {T} = Node{T}
195+ with_type_parameters (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode{T}
221196
222197function default_allocator (:: Type{N} , :: Type{T} ) where {N<: AbstractExpressionNode ,T}
223198 return with_type_parameters (N, T)()
224199end
200+ default_allocator (:: Type{<:Node} , :: Type{T} ) where {T} = Node {T} ()
201+ default_allocator (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode {T} ()
225202
226203""" Trait declaring whether nodes share children or not."""
227204preserve_sharing (:: Union{Type{<:AbstractNode},AbstractNode} ) = false
205+ preserve_sharing (:: Union{Type{<:Node},Node} ) = false
228206preserve_sharing (:: Union{Type{<:GraphNode},GraphNode} ) = true
229207
230208include (" base.jl" )
231209
232210# ! format: off
233211@inline function (:: Type{N} )(
234212 :: Type{T1} = Undefined; val= nothing , feature= nothing , op= nothing , l= nothing , r= nothing , children= nothing , allocator:: F = default_allocator,
235- ) where {T1,N<: AbstractExpressionNode{T} where T ,F}
236- _children = if l != = nothing && r === nothing
237- @assert children = == nothing
238- (l,)
239- elseif l != = nothing && r ! == nothing
240- @assert children === nothing
241- (l, r)
242- else
243- children
213+ ) where {T1,N<: AbstractExpressionNode ,F}
214+ validate_not_all_defaults (N, val, feature, op, l, r, children)
215+ if children ! == nothing
216+ @assert l === nothing && r === nothing
217+ if length (children) == 1
218+ return node_factory (N, T1, val, feature, op, only (children), nothing , allocator)
219+ else
220+ return node_factory (N, T1, val, feature, op, children ... , allocator)
221+ end
244222 end
245- validate_not_all_defaults (N, val, feature, op, _children)
246- return node_factory (N, T1, val, feature, op, _children, allocator)
223+ return node_factory (N, T1, val, feature, op, l, r, allocator)
247224end
248- function validate_not_all_defaults (:: Type{N} , val, feature, op, children) where {N<: AbstractExpressionNode }
225+ function validate_not_all_defaults (:: Type{N} , val, feature, op, l, r, children) where {N<: AbstractExpressionNode }
249226 return nothing
250227end
251- function validate_not_all_defaults (:: Type{N} , val, feature, op, children) where {T,N<: AbstractExpressionNode{T} }
252- if val === nothing && feature === nothing && op === nothing && children === nothing
228+ function validate_not_all_defaults (:: Type{N} , val, feature, op, l, r, children) where {T,N<: AbstractExpressionNode{T} }
229+ if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing
253230 error (
254231 " Encountered the call for $N () inside the generic constructor. "
255- * " Did you forget to define `$(Base. typename (N). wrapper) {T,D }() where {T,D } = new{T,D }()`?"
232+ * " Did you forget to define `$(Base. typename (N). wrapper) {T}() where {T} = new{T}()`?"
256233 )
257234 end
258235 return nothing
259236end
260237""" Create a constant leaf."""
261238@inline function node_factory (
262- :: Type{N} , :: Type{T1} , val:: T2 , :: Nothing , :: Nothing , :: Nothing , allocator:: F ,
239+ :: Type{N} , :: Type{T1} , val:: T2 , :: Nothing , :: Nothing , :: Nothing , :: Nothing , allocator:: F ,
263240) where {N,T1,T2,F}
264241 T = node_factory_type (N, T1, T2)
265242 n = allocator (N, T)
270247end
271248""" Create a variable leaf, to store data."""
272249@inline function node_factory (
273- :: Type{N} , :: Type{T1} , :: Nothing , feature:: Integer , :: Nothing , :: Nothing , allocator:: F ,
250+ :: Type{N} , :: Type{T1} , :: Nothing , feature:: Integer , :: Nothing , :: Nothing , :: Nothing , allocator:: F ,
274251) where {N,T1,F}
275252 T = node_factory_type (N, T1, DEFAULT_NODE_TYPE)
276253 n = allocator (N, T)
@@ -279,18 +256,28 @@ end
279256 n. feature = feature
280257 return n
281258end
282- """ Create an operator node."""
259+ """ Create a unary operator node."""
260+ @inline function node_factory (
261+ :: Type{N} , :: Type{T1} , :: Nothing , :: Nothing , op:: Integer , l:: AbstractExpressionNode{T2} , :: Nothing , allocator:: F ,
262+ ) where {N,T1,T2,F}
263+ @assert l isa N
264+ T = T2 # Always prefer existing nodes, so we don't mess up references from conversion
265+ n = allocator (N, T)
266+ n. degree = 1
267+ n. op = op
268+ n. l = l
269+ return n
270+ end
271+ """ Create a binary operator node."""
283272@inline function node_factory (
284- :: Type{N} , :: Type , :: Nothing , :: Nothing , op:: Integer , children:: Tuple , allocator:: F ,
285- ) where {N<: AbstractExpressionNode ,F}
286- T = promote_type (map (eltype, children)... ) # Always prefer existing nodes, so we don't mess up references from conversion
287- D2 = length (children)
288- @assert D2 <= max_degree (N)
289- NT = with_type_parameters (N, T)
273+ :: Type{N} , :: Type{T1} , :: Nothing , :: Nothing , op:: Integer , l:: AbstractExpressionNode{T2} , r:: AbstractExpressionNode{T3} , allocator:: F ,
274+ ) where {N,T1,T2,T3,F}
275+ T = promote_type (T2, T3)
290276 n = allocator (N, T)
291- n. degree = D2
277+ n. degree = 2
292278 n. op = op
293- n. children = ntuple (i -> i <= D2 ? Ref (convert (NT, children[i])) : Ref {NT} (), Val (max_degree (N)))
279+ n. l = T2 === T ? l : convert (with_type_parameters (N, T), l)
280+ n. r = T3 === T ? r : convert (with_type_parameters (N, T), r)
294281 return n
295282end
296283
@@ -331,14 +318,14 @@ function (::Type{N})(
331318 return N (; feature= i)
332319end
333320
334- function Base. promote_rule (:: Type{Node{T1,D }} , :: Type{Node{T2,D }} ) where {T1,T2,D }
335- return Node{promote_type (T1, T2),D }
321+ function Base. promote_rule (:: Type{Node{T1}} , :: Type{Node{T2}} ) where {T1,T2}
322+ return Node{promote_type (T1, T2)}
336323end
337- function Base. promote_rule (:: Type{GraphNode{T1,D }} , :: Type{Node{T2,D }} ) where {T1,T2,D }
338- return GraphNode{promote_type (T1, T2),D }
324+ function Base. promote_rule (:: Type{GraphNode{T1}} , :: Type{Node{T2}} ) where {T1,T2}
325+ return GraphNode{promote_type (T1, T2)}
339326end
340- function Base. promote_rule (:: Type{GraphNode{T1,D }} , :: Type{GraphNode{T2,D }} ) where {T1,T2,D }
341- return GraphNode{promote_type (T1, T2),D }
327+ function Base. promote_rule (:: Type{GraphNode{T1}} , :: Type{GraphNode{T2}} ) where {T1,T2}
328+ return GraphNode{promote_type (T1, T2)}
342329end
343330
344331# TODO : Verify using this helps with garbage collection
0 commit comments