Skip to content

Commit 6146408

Browse files
committed
wip: undo all changes
1 parent 1e672bc commit 6146408

File tree

7 files changed

+251
-225
lines changed

7 files changed

+251
-225
lines changed

src/DynamicExpressions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ import .NodeModule:
4747
constructorof,
4848
with_type_parameters,
4949
preserve_sharing,
50-
max_degree,
5150
leaf_copy,
5251
branch_copy,
5352
leaf_hash,

src/Node.jl

Lines changed: 104 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,30 @@ module NodeModule
33
using DispatchDoctor: @unstable
44

55
import ..OperatorEnumModule: AbstractOperatorEnum
6-
import ..UtilsModule: deprecate_varmap, Undefined
6+
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
77

88
const DEFAULT_NODE_TYPE = Float32
99

1010
"""
11-
AbstractNode{D}
11+
AbstractNode
1212
13-
Abstract type for D-arity trees. Must have the following fields:
13+
Abstract type for binary trees. Must have the following fields:
1414
1515
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
1616
then `l` needs to be defined as the left child. If 2,
1717
then `r` also needs to be defined as the right child.
18-
- `children`: A collection of D references to children nodes.
19-
20-
# Deprecated fields
21-
22-
- `l::AbstractNode{D}`: Left child of the current node. Should only be
18+
- `l::AbstractNode`: Left child of the current node. Should only be
2319
defined if `degree >= 1`; otherwise, leave it undefined (see the
2420
the constructors of [`Node{T}`](@ref) for an example).
2521
Don't use `nothing` to represent an undefined value
2622
as it will incur a large performance penalty.
27-
- `r::AbstractNode{D}`: Right child of the current node. Should only
23+
- `r::AbstractNode`: Right child of the current node. Should only
2824
be defined if `degree == 2`.
2925
"""
30-
abstract type AbstractNode{D} end
26+
abstract type AbstractNode end
3127

3228
"""
33-
AbstractExpressionNode{T,D} <: AbstractNode{D}
29+
AbstractExpressionNode{T} <: AbstractNode
3430
3531
Abstract type for nodes that represent an expression.
3632
Along with the fields required for `AbstractNode`,
@@ -71,27 +67,11 @@ You likely do not need to, but you could choose to override the following:
7167
- `with_type_parameters`
7268
7369
"""
74-
abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end
75-
76-
for N in (:Node, :GraphNode)
77-
@eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D}
78-
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
79-
constant::Bool # false if variable
80-
val::T # If is a constant, this stores the actual value
81-
feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
82-
op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
83-
children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes
84-
85-
#################
86-
## Constructors:
87-
#################
88-
$N{_T,_D}() where {_T,_D} = new{_T,_D::Int}()
89-
end
90-
end
70+
abstract type AbstractExpressionNode{T} <: AbstractNode end
9171

9272
#! format: off
9373
"""
94-
Node{T,D} <: AbstractExpressionNode{T,D}
74+
Node{T} <: AbstractExpressionNode{T}
9575
9676
Node defines a symbolic expression stored in a binary tree.
9777
A single `Node` instance is one "node" of this tree, and
@@ -101,42 +81,63 @@ nodes, you can evaluate or print a given expression.
10181
# Fields
10282
10383
- `degree::UInt8`: Degree of the node. 0 for constants, 1 for
104-
unary operators, 2 for binary operators, etc. Maximum of `D`.
84+
unary operators, 2 for binary operators.
10585
- `constant::Bool`: Whether the node is a constant.
10686
- `val::T`: Value of the node. If `degree==0`, and `constant==true`,
10787
this is the value of the constant. It has a type specified by the
10888
overall type of the `Node` (e.g., `Float64`).
10989
- `feature::UInt16`: Index of the feature to use in the
110-
case of a feature node. Only defined if `degree == 0 && constant == false`.
90+
case of a feature node. Only used if `degree==0` and `constant==false`.
91+
Only defined if `degree == 0 && constant == false`.
11192
- `op::UInt8`: If `degree==1`, this is the index of the operator
11293
in `operators.unaops`. If `degree==2`, this is the index of the
11394
operator in `operators.binops`. In other words, this is an enum
11495
of the operators, and is dependent on the specific `OperatorEnum`
11596
object. Only defined if `degree >= 1`
116-
- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree`
97+
- `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`.
98+
Same type as the parent node.
99+
- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`.
100+
Same type as the parent node. This is to be passed as the right
101+
argument to the binary operator.
117102
118103
# Constructors
119104
120105
121-
Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
122-
Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
106+
Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
107+
Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
123108
124109
Create a new node in an expression tree. If `T` is not specified in either the type or the
125-
first argument, it will be inferred from the value of `val` passed or the children.
126-
The `children` keyword is used to pass in a collection of children nodes.
110+
first argument, it will be inferred from the value of `val` passed or `l` and/or `r`.
111+
If it cannot be inferred from these, it will default to `Float32`.
112+
113+
The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This
114+
is to permit the use of splatting in constructors.
127115
128116
You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`.
129117
130118
You may also choose to specify a default memory allocator for the node other than simply `Node{T}()`
131119
in the `allocator` keyword argument.
132120
"""
133-
Node
134-
121+
mutable struct Node{T} <: AbstractExpressionNode{T}
122+
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
123+
constant::Bool # false if variable
124+
val::T # If is a constant, this stores the actual value
125+
# ------------------- (possibly undefined below)
126+
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
127+
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
128+
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
129+
r::Node{T} # Right child node. Only defined for degree=2.
130+
131+
#################
132+
## Constructors:
133+
#################
134+
Node{_T}() where {_T} = new{_T}()
135+
end
135136

136137
"""
137-
GraphNode{T,D} <: AbstractExpressionNode{T,D}
138+
GraphNode{T} <: AbstractExpressionNode{T}
138139
139-
Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some
140+
Exactly the same as [`Node{T}`](@ref), but with the assumption that some
140141
nodes will be shared. All copies of this graph-like structure will
141142
be performed with this assumption, to preserve structure of the graph.
142143
@@ -145,7 +146,7 @@ be performed with this assumption, to preserve structure of the graph.
145146
```julia
146147
julia> operators = OperatorEnum(;
147148
binary_operators=[+, -, *], unary_operators=[cos, sin]
148-
);
149+
);
149150
150151
julia> x = GraphNode(feature=1)
151152
x1
@@ -164,38 +165,17 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes
164165
are created simply by using the same node in multiple places
165166
when constructing or setting properties.
166167
"""
167-
GraphNode
168-
169-
@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol)
170-
if k == :l
171-
# TODO: Should a depwarn be raised here? Or too slow?
172-
return getfield(n, :children)[1][]
173-
elseif k == :r
174-
return getfield(n, :children)[2][]
175-
else
176-
return getfield(n, k)
177-
end
178-
end
179-
@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v)
180-
if k == :l
181-
getfield(n, :children)[1][] = v
182-
elseif k == :r
183-
getfield(n, :children)[2][] = v
184-
elseif k == :degree
185-
setfield!(n, :degree, convert(UInt8, v))
186-
elseif k == :constant
187-
setfield!(n, :constant, convert(Bool, v))
188-
elseif k == :feature
189-
setfield!(n, :feature, convert(UInt16, v))
190-
elseif k == :op
191-
setfield!(n, :op, convert(UInt8, v))
192-
elseif k == :val
193-
setfield!(n, :val, convert(eltype(n), v))
194-
elseif k == :children
195-
setfield!(n, :children, v)
196-
else
197-
error("Invalid property: $k")
198-
end
168+
mutable struct GraphNode{T} <: AbstractExpressionNode{T}
169+
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
170+
constant::Bool # false if variable
171+
val::T # If is a constant, this stores the actual value
172+
# ------------------- (possibly undefined below)
173+
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
174+
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
175+
l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2.
176+
r::GraphNode{T} # Right child node. Only defined for degree=2.
177+
178+
GraphNode{_T}() where {_T} = new{_T}()
199179
end
200180

201181
################################################################################
@@ -204,62 +184,59 @@ end
204184
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
205185
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
206186

207-
max_degree(::Type{<:AbstractNode}) = 2 # Default
208-
max_degree(::Type{<:AbstractNode{D}}) where {D} = D
209-
210-
@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T}
211-
@unstable constructorof(::Type{N}) where {N<:GraphNode} =
212-
GraphNode{T,max_degree(N)} where {T}
187+
@unstable constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper
188+
@unstable constructorof(::Type{<:Node}) = Node
189+
@unstable constructorof(::Type{<:GraphNode}) = GraphNode
213190

214-
with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)}
215-
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T}
216-
return GraphNode{T,max_degree(N)}
191+
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
192+
return constructorof(N){T}
217193
end
218-
219-
# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
220-
# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
194+
with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T}
195+
with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}
221196

222197
function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
223198
return with_type_parameters(N, T)()
224199
end
200+
default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}()
201+
default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}()
225202

226203
"""Trait declaring whether nodes share children or not."""
227204
preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false
205+
preserve_sharing(::Union{Type{<:Node},Node}) = false
228206
preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true
229207

230208
include("base.jl")
231209

232210
#! format: off
233211
@inline function (::Type{N})(
234212
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator,
235-
) where {T1,N<:AbstractExpressionNode{T} where T,F}
236-
_children = if l !== nothing && r === nothing
237-
@assert children === nothing
238-
(l,)
239-
elseif l !== nothing && r !== nothing
240-
@assert children === nothing
241-
(l, r)
242-
else
243-
children
213+
) where {T1,N<:AbstractExpressionNode,F}
214+
validate_not_all_defaults(N, val, feature, op, l, r, children)
215+
if children !== nothing
216+
@assert l === nothing && r === nothing
217+
if length(children) == 1
218+
return node_factory(N, T1, val, feature, op, only(children), nothing, allocator)
219+
else
220+
return node_factory(N, T1, val, feature, op, children..., allocator)
221+
end
244222
end
245-
validate_not_all_defaults(N, val, feature, op, _children)
246-
return node_factory(N, T1, val, feature, op, _children, allocator)
223+
return node_factory(N, T1, val, feature, op, l, r, allocator)
247224
end
248-
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode}
225+
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode}
249226
return nothing
250227
end
251-
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {T,N<:AbstractExpressionNode{T}}
252-
if val === nothing && feature === nothing && op === nothing && children === nothing
228+
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}}
229+
if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing
253230
error(
254231
"Encountered the call for $N() inside the generic constructor. "
255-
* "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?"
232+
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
256233
)
257234
end
258235
return nothing
259236
end
260237
"""Create a constant leaf."""
261238
@inline function node_factory(
262-
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F,
239+
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F,
263240
) where {N,T1,T2,F}
264241
T = node_factory_type(N, T1, T2)
265242
n = allocator(N, T)
@@ -270,7 +247,7 @@ end
270247
end
271248
"""Create a variable leaf, to store data."""
272249
@inline function node_factory(
273-
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F,
250+
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F,
274251
) where {N,T1,F}
275252
T = node_factory_type(N, T1, DEFAULT_NODE_TYPE)
276253
n = allocator(N, T)
@@ -279,18 +256,28 @@ end
279256
n.feature = feature
280257
return n
281258
end
282-
"""Create an operator node."""
259+
"""Create a unary operator node."""
260+
@inline function node_factory(
261+
::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F,
262+
) where {N,T1,T2,F}
263+
@assert l isa N
264+
T = T2 # Always prefer existing nodes, so we don't mess up references from conversion
265+
n = allocator(N, T)
266+
n.degree = 1
267+
n.op = op
268+
n.l = l
269+
return n
270+
end
271+
"""Create a binary operator node."""
283272
@inline function node_factory(
284-
::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F,
285-
) where {N<:AbstractExpressionNode,F}
286-
T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion
287-
D2 = length(children)
288-
@assert D2 <= max_degree(N)
289-
NT = with_type_parameters(N, T)
273+
::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F,
274+
) where {N,T1,T2,T3,F}
275+
T = promote_type(T2, T3)
290276
n = allocator(N, T)
291-
n.degree = D2
277+
n.degree = 2
292278
n.op = op
293-
n.children = ntuple(i -> i <= D2 ? Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N)))
279+
n.l = T2 === T ? l : convert(with_type_parameters(N, T), l)
280+
n.r = T3 === T ? r : convert(with_type_parameters(N, T), r)
294281
return n
295282
end
296283

@@ -331,14 +318,14 @@ function (::Type{N})(
331318
return N(; feature=i)
332319
end
333320

334-
function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
335-
return Node{promote_type(T1, T2),D}
321+
function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2}
322+
return Node{promote_type(T1, T2)}
336323
end
337-
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
338-
return GraphNode{promote_type(T1, T2),D}
324+
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2}
325+
return GraphNode{promote_type(T1, T2)}
339326
end
340-
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D}
341-
return GraphNode{promote_type(T1, T2),D}
327+
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2}
328+
return GraphNode{promote_type(T1, T2)}
342329
end
343330

344331
# TODO: Verify using this helps with garbage collection

0 commit comments

Comments
 (0)