Skip to content

Commit 3ed6b41

Browse files
committed
fix: various aspects of degree interface
1 parent ab7c65c commit 3ed6b41

File tree

3 files changed

+86
-51
lines changed

3 files changed

+86
-51
lines changed

src/Node.jl

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ for N in (:Node, :GraphNode)
8585
#################
8686
## Constructors:
8787
#################
88-
$N{_T,_D}() where {_T,_D} = new{_T,_D}()
88+
$N{_T,_D}() where {_T,_D} = new{_T,_D::Int}()
8989
end
9090
end
9191

@@ -166,26 +166,62 @@ when constructing or setting properties.
166166
"""
167167
GraphNode
168168

169+
@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol)
170+
if k == :l
171+
# TODO: Should a depwarn be raised here? Or too slow?
172+
return getfield(n, :children)[1][]
173+
elseif k == :r
174+
return getfield(n, :children)[2][]
175+
else
176+
return getfield(n, k)
177+
end
178+
end
179+
@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v)
180+
if k == :l
181+
getfield(n, :children)[1][] = v
182+
elseif k == :r
183+
getfield(n, :children)[2][] = v
184+
elseif k == :degree
185+
setfield!(n, :degree, convert(UInt8, v))
186+
elseif k == :constant
187+
setfield!(n, :constant, convert(Bool, v))
188+
elseif k == :feature
189+
setfield!(n, :feature, convert(UInt16, v))
190+
elseif k == :op
191+
setfield!(n, :op, convert(UInt8, v))
192+
elseif k == :val
193+
setfield!(n, :val, convert(eltype(n), v))
194+
elseif k == :children
195+
setfield!(n, :children, v)
196+
else
197+
error("Invalid property: $k")
198+
end
199+
end
200+
169201
################################################################################
170202
#! format: on
171203

172204
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
173205
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
174206

175-
function max_degree(::Type{N}) where {N<:AbstractExpressionNode}
176-
return (N isa UnionAll ? N.body : N).parameters[2]
177-
end
207+
max_degree(::Type{<:AbstractNode}) = 2 # Default
208+
max_degree(::Type{<:AbstractNode{D}}) where {D} = D
209+
210+
@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T}
211+
@unstable constructorof(::Type{N}) where {N<:GraphNode} =
212+
GraphNode{T,max_degree(N)} where {T}
178213

179-
@unstable constructorof(::Type{<:Node}) = Node
180-
@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T
181-
@unstable constructorof(::Type{<:GraphNode}) = GraphNode
182-
@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T
214+
with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)}
215+
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T}
216+
return GraphNode{T,max_degree(N)}
217+
end
183218

184-
with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}
185-
with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}
219+
# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
220+
# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
186221

187-
default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}()
188-
default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}()
222+
function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T}
223+
return with_type_parameters(N, T)()
224+
end
189225

190226
"""Trait declaring whether nodes share children or not."""
191227
preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false
@@ -194,13 +230,9 @@ preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true
194230
include("base.jl")
195231

196232
#! format: off
197-
@inline function (::Type{N})(
198-
::Type{T1}=Undefined; kws...
199-
) where {T1,N<:AbstractExpressionNode,F}
200-
end
201233
@inline function (::Type{N})(
202234
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator,
203-
) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F}
235+
) where {T1,N<:AbstractExpressionNode{T} where T,F}
204236
_children = if l !== nothing && r === nothing
205237
@assert children === nothing
206238
(l,)
@@ -230,7 +262,7 @@ end
230262
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F,
231263
) where {N,T1,T2,F}
232264
T = node_factory_type(N, T1, T2)
233-
n = allocator(N, T, D)
265+
n = allocator(N, T)
234266
n.degree = 0
235267
n.constant = true
236268
n.val = convert(T, val)
@@ -241,23 +273,24 @@ end
241273
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F,
242274
) where {N,T1,F}
243275
T = node_factory_type(N, T1, DEFAULT_NODE_TYPE)
244-
n = allocator(N, T, D)
276+
n = allocator(N, T)
245277
n.degree = 0
246278
n.constant = false
247279
n.feature = feature
248280
return n
249281
end
250282
"""Create an operator node."""
251283
@inline function node_factory(
252-
::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F,
253-
) where {D,N<:AbstractExpressionNode{T where T,D},F,D2}
284+
::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F,
285+
) where {N<:AbstractExpressionNode,F}
254286
T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion
255-
NT = with_type_parameters(N, T, D)
256-
n = allocator(N, T, D)
287+
D2 = length(children)
288+
@assert D2 <= max_degree(N)
289+
NT = with_type_parameters(N, T)
290+
n = allocator(N, T)
257291
n.degree = D2
258292
n.op = op
259-
n.children
260-
# map(Ref, children)
293+
n.children = ntuple(i -> i <= D2 ? Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N)))
261294
return n
262295
end
263296

@@ -298,14 +331,14 @@ function (::Type{N})(
298331
return N(; feature=i)
299332
end
300333

301-
function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2}
302-
return Node{promote_type(T1, T2)}
334+
function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
335+
return Node{promote_type(T1, T2),D}
303336
end
304-
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2}
305-
return GraphNode{promote_type(T1, T2)}
337+
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
338+
return GraphNode{promote_type(T1, T2),D}
306339
end
307-
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2}
308-
return GraphNode{promote_type(T1, T2)}
340+
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D}
341+
return GraphNode{promote_type(T1, T2),D}
309342
end
310343

311344
# TODO: Verify using this helps with garbage collection

src/NodeUtils.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
module NodeUtilsModule
22

3-
using StaticArrays: MVector
43
import Compat: Returns
54
import ..NodeModule:
65
AbstractNode,
76
AbstractExpressionNode,
8-
GeneralNode,
97
Node,
108
preserve_sharing,
119
constructorof,
@@ -145,38 +143,43 @@ end
145143
## Assign index to nodes of a tree
146144
# This will mirror a Node struct, rather
147145
# than adding a new attribute to Node.
148-
struct NodeIndex{T,D} <: AbstractNode{D,false}
146+
struct NodeIndex{T,D} <: AbstractNode{D}
149147
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
150148
val::T # If is a constant, this stores the actual value
151149
# ------------------- (possibly undefined below)
152-
children::MVector{D,NodeIndex{T,D}}
153-
154-
NodeIndex(::Type{_T}, ::Type{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T))
155-
NodeIndex(::Type{_T}, ::Type{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val))
156-
function NodeIndex(::Type{_T}, ::Type{_D}, children::Vararg{Any,_D2}) where {_T,_D,_D2}
157-
_children = MVector{_D,NodeIndex{_T,_D}}(undef)
158-
_children[begin:_D2] = children
150+
children::NTuple{D,Base.RefValue{NodeIndex{T,D}}}
151+
152+
NodeIndex(::Type{_T}, ::Val{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T))
153+
NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val))
154+
function NodeIndex(
155+
::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2}
156+
) where {_T,_D,_D2}
157+
_children = ntuple(
158+
i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D)
159+
)
159160
return new{_T,_D}(1, zero(_T), _children)
160161
end
161162
end
162163
# Sharing is never needed for NodeIndex,
163164
# as we trace over the node we are indexing on.
164165
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false
165166

166-
function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T}
167+
function index_constant_nodes(
168+
tree::AbstractExpressionNode{Ti,D} where {Ti}, ::Type{T}=UInt16
169+
) where {D,T}
167170
# Essentially we copy the tree, replacing the values
168171
# with indices
169172
constant_index = Ref(T(0))
170173
return tree_mapreduce(
171174
t -> if t.constant
172-
NodeIndex(T, (constant_index[] += T(1)))
175+
NodeIndex(T, Val(D), (constant_index[] += T(1)))
173176
else
174-
NodeIndex(T)
177+
NodeIndex(T, Val(D))
175178
end,
176179
t -> nothing,
177-
(_, c...) -> NodeIndex(T, c...),
180+
(_, c...) -> NodeIndex(T, Val(D), c...),
178181
tree,
179-
NodeIndex{T};
182+
NodeIndex{T,D};
180183
)
181184
end
182185

src/ParametricExpression.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module ParametricExpressionModule
22

33
using DispatchDoctor: @stable, @unstable
4-
using StaticArrays: MVector
54
using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk
65

76
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
@@ -35,7 +34,7 @@ import ..ValueInterfaceModule:
3534
count_scalar_constants, pack_scalar_constants!, unpack_scalar_constants
3635

3736
"""A type of expression node that also stores a parameter index"""
38-
mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared}
37+
mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D}
3938
degree::UInt8
4039
constant::Bool # if true => constant; if false, then check `is_parameter`
4140
val::T
@@ -45,10 +44,10 @@ mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared}
4544
parameter::UInt16 # Stores index of per-class parameter
4645

4746
op::UInt8
48-
children::MVector{D,ParametricNode{T,D}} # Children nodes
47+
children::NTuple{D,Base.RefValue{ParametricNode{T,D}}} # Children nodes
4948

50-
function ParametricNode{_T,_D,_shared}() where {_T,_D,_shared}
51-
n = new{_T,_D,_shared}()
49+
function ParametricNode{_T,_D}() where {_T,_D}
50+
n = new{_T,_D}()
5251
n.is_parameter = false
5352
n.parameter = UInt16(0)
5453
return n

0 commit comments

Comments
 (0)