Skip to content

Commit ab7c65c

Browse files
committed
wip
1 parent 6b90350 commit ab7c65c

File tree

2 files changed

+57
-52
lines changed

2 files changed

+57
-52
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1313
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
16-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1716
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1817

1918
[weakdeps]

src/Node.jl

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
11
module NodeModule
22

33
using DispatchDoctor: @unstable
4-
using StaticArrays: SizedVector
54

65
import ..OperatorEnumModule: AbstractOperatorEnum
76
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
87

98
const DEFAULT_NODE_TYPE = Float32
109

1110
"""
12-
AbstractNode{D,shared}
11+
AbstractNode{D}
1312
14-
Abstract type for D-arity trees. If `shared`, the node type
15-
permits graph-like structures. Must have the following fields:
13+
Abstract type for D-arity trees. Must have the following fields:
1614
1715
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
1816
then `l` needs to be defined as the left child. If 2,
1917
then `r` also needs to be defined as the right child.
20-
- `children`: A collection of D children nodes.
18+
- `children`: A collection of D references to children nodes.
2119
2220
# Deprecated fields
2321
@@ -29,7 +27,7 @@ permits graph-like structures. Must have the following fields:
2927
- `r::AbstractNode{D}`: Right child of the current node. Should only
3028
be defined if `degree == 2`.
3129
"""
32-
abstract type AbstractNode{D,shared} end
30+
abstract type AbstractNode{D} end
3331

3432
"""
3533
AbstractExpressionNode{T,D} <: AbstractNode{D}
@@ -73,25 +71,27 @@ You likely do not need to, but you could choose to override the following:
7371
- `with_type_parameters`
7472
7573
"""
76-
abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end
77-
78-
mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared}
79-
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
80-
constant::Bool # false if variable
81-
val::T # If is a constant, this stores the actual value
82-
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
83-
op::UInt8 # If operator, this is the index of the operator in the degree-specific operator enum
84-
children::SizedVector{D,GeneralNode{T,D,shared}} # Children nodes
85-
86-
#################
87-
## Constructors:
88-
#################
89-
GeneralNode{_T,_D,_shared}() where {_T,_D,_shared} = new{_T,_D,_shared}()
74+
abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end
75+
76+
for N in (:Node, :GraphNode)
77+
@eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D}
78+
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
79+
constant::Bool # false if variable
80+
val::T # If is a constant, this stores the actual value
81+
feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
82+
op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
83+
children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes
84+
85+
#################
86+
## Constructors:
87+
#################
88+
$N{_T,_D}() where {_T,_D} = new{_T,_D}()
89+
end
9090
end
9191

9292
#! format: off
9393
"""
94-
Node{T} <: AbstractExpressionNode{T,2}
94+
Node{T,D} <: AbstractExpressionNode{T,D}
9595
9696
Node defines a symbolic expression stored in a binary tree.
9797
A single `Node` instance is one "node" of this tree, and
@@ -113,7 +113,7 @@ nodes, you can evaluate or print a given expression.
113113
operator in `operators.binops`. In other words, this is an enum
114114
of the operators, and is dependent on the specific `OperatorEnum`
115115
object. Only defined if `degree >= 1`
116-
- `children::SizedArray{D,Node{T,D}}`: Children of the node. Only defined up to `degree`
116+
- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree`
117117
118118
# Constructors
119119
@@ -130,13 +130,13 @@ You may also construct nodes via the convenience operators generated by creating
130130
You may also choose to specify a default memory allocator for the node other than simply `Node{T}()`
131131
in the `allocator` keyword argument.
132132
"""
133-
const Node{T} = GeneralNode{T,2,false}
133+
Node
134134

135135

136136
"""
137-
GraphNode{T} <: AbstractExpressionNode{T}
137+
GraphNode{T,D} <: AbstractExpressionNode{T,D}
138138
139-
Exactly the same as [`Node{T}`](@ref), but with the assumption that some
139+
Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some
140140
nodes will be shared. All copies of this graph-like structure will
141141
be performed with this assumption, to preserve structure of the graph.
142142
@@ -164,7 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes
164164
are created simply by using the same node in multiple places
165165
when constructing or setting properties.
166166
"""
167-
const GraphNode{T} = GeneralNode{T,2,true}
167+
GraphNode
168168

169169
################################################################################
170170
#! format: on
@@ -177,30 +177,41 @@ function max_degree(::Type{N}) where {N<:AbstractExpressionNode}
177177
end
178178

179179
@unstable constructorof(::Type{<:Node}) = Node
180+
@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T
180181
@unstable constructorof(::Type{<:GraphNode}) = GraphNode
182+
@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T
181183

182-
with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T}
183-
with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}
184+
with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}
185+
with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}
184186

185-
default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}()
186-
default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}()
187+
default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}()
188+
default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}()
187189

188190
"""Trait declaring whether nodes share children or not."""
189191
preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false
190-
function preserve_sharing(
191-
::Union{Type{<:G},G}
192-
) where {shared,G<:GeneralNode{T,D,shared} where {T,D}}
193-
return shared
194-
end
192+
preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true
195193

196194
include("base.jl")
197195

198196
#! format: off
199197
@inline function (::Type{N})(
200-
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, children=nothing, allocator::F=default_allocator,
198+
::Type{T1}=Undefined; kws...
201199
) where {T1,N<:AbstractExpressionNode,F}
202-
validate_not_all_defaults(N, val, feature, op, children)
203-
return node_factory(N, T1, val, feature, op, children, allocator)
200+
end
201+
@inline function (::Type{N})(
202+
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator,
203+
) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F}
204+
_children = if l !== nothing && r === nothing
205+
@assert children === nothing
206+
(l,)
207+
elseif l !== nothing && r !== nothing
208+
@assert children === nothing
209+
(l, r)
210+
else
211+
children
212+
end
213+
validate_not_all_defaults(N, val, feature, op, _children)
214+
return node_factory(N, T1, val, feature, op, _children, allocator)
204215
end
205216
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode}
206217
return nothing
@@ -209,7 +220,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, children) where
209220
if val === nothing && feature === nothing && op === nothing && children === nothing
210221
error(
211222
"Encountered the call for $N() inside the generic constructor. "
212-
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
223+
* "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?"
213224
)
214225
end
215226
return nothing
@@ -219,7 +230,7 @@ end
219230
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F,
220231
) where {N,T1,T2,F}
221232
T = node_factory_type(N, T1, T2)
222-
n = allocator(N, T)
233+
n = allocator(N, T, D)
223234
n.degree = 0
224235
n.constant = true
225236
n.val = convert(T, val)
@@ -230,7 +241,7 @@ end
230241
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F,
231242
) where {N,T1,F}
232243
T = node_factory_type(N, T1, DEFAULT_NODE_TYPE)
233-
n = allocator(N, T)
244+
n = allocator(N, T, D)
234245
n.degree = 0
235246
n.constant = false
236247
n.feature = feature
@@ -239,19 +250,14 @@ end
239250
"""Create an operator node."""
240251
@inline function node_factory(
241252
::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F,
242-
) where {N,F,D2}
243-
D = max_degree(N)
244-
@assert D2 <= D
253+
) where {D,N<:AbstractExpressionNode{T where T,D},F,D2}
245254
T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion
246-
NT = with_type_parameters(N, T)
247-
n = allocator(N, T)
255+
NT = with_type_parameters(N, T, D)
256+
n = allocator(N, T, D)
248257
n.degree = D2
249258
n.op = op
250-
ar = SizedVector{D,NT}(undef)
251-
for i in 1:D2
252-
ar[i] = children[i]
253-
end
254-
n.children = ar
259+
n.children
260+
# map(Ref, children)
255261
return n
256262
end
257263

0 commit comments

Comments
 (0)