Skip to content

Commit 2a0bd05

Browse files
committed
fix: various aspects of degree interface
1 parent 8707d24 commit 2a0bd05

File tree

10 files changed

+50
-38
lines changed

10 files changed

+50
-38
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import .NodeModule:
4747
constructorof,
4848
with_type_parameters,
4949
preserve_sharing,
50+
max_degree,
5051
leaf_copy,
5152
branch_copy,
5253
leaf_hash,

src/Node.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ end
219219
# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
220220
# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
221221

222-
function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T}
222+
function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
223223
return with_type_parameters(N, T)()
224224
end
225225

src/NodeUtils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,17 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D}
165165
end
166166
NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T))
167167

168+
@inline function Base.getproperty(n::NodeIndex, k::Symbol)
169+
if k == :l
170+
# TODO: Should a depwarn be raised here? Or too slow?
171+
return getfield(n, :children)[1][]
172+
elseif k == :r
173+
return getfield(n, :children)[2][]
174+
else
175+
return getfield(n, k)
176+
end
177+
end
178+
168179
# Sharing is never needed for NodeIndex,
169180
# as we trace over the node we are indexing on.
170181
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false

src/base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function tree_mapreduce(
9494
f_on_shared::H=(result, is_shared) -> result,
9595
break_sharing::Val{BS}=Val(false),
9696
) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS}
97-
sharing = preserve_sharing(typeof(tree)) && !break_sharing
97+
sharing = preserve_sharing(typeof(tree)) && !BS
9898

9999
RT == Undefined &&
100100
sharing &&

test/test_base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ end
3232

3333
@testset "collect" begin
3434
ctree = copy(tree)
35-
@test typeof(first(collect(ctree))) == Node{Float64}
35+
@test typeof(first(collect(ctree))) <: Node{Float64}
3636
@test objectid(first(collect(ctree))) == objectid(ctree)
3737
@test objectid(first(collect(ctree))) == objectid(ctree)
3838
@test objectid(first(collect(ctree))) == objectid(ctree)
39-
@test typeof(collect(ctree)) == Vector{Node{Float64}}
39+
@test typeof(collect(ctree)) <: Vector{<:Node{Float64}}
4040
@test length(collect(ctree)) == 24
4141
@test sum((t -> (t.degree == 0 && t.constant) ? t.val : 0.0).(collect(ctree))) 11.6
4242
end

test/test_custom_node_type.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
using DynamicExpressions
22
using Test
33

4-
mutable struct MyCustomNode{A,B} <: AbstractNode
4+
mutable struct MyCustomNode{A,B} <: AbstractNode{2}
55
degree::Int
66
val1::A
77
val2::B
8-
l::MyCustomNode{A,B}
9-
r::MyCustomNode{A,B}
8+
children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}}
109

1110
MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2)
12-
MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l)
13-
MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r)
11+
function MyCustomNode(val1, val2, l)
12+
return new{typeof(val1),typeof(val2)}(
13+
1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}())
14+
)
15+
end
16+
function MyCustomNode(val1, val2, l, r)
17+
return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r)))
18+
end
1419
end
1520

1621
node1 = MyCustomNode(1.0, 2)
@@ -24,7 +29,7 @@ node2 = MyCustomNode(1.5, 3, node1)
2429

2530
@test typeof(node2) == MyCustomNode{Float64,Int}
2631
@test node2.degree == 1
27-
@test node2.l.degree == 0
32+
@test node2.children[1][].degree == 0
2833
@test count_depth(node2) == 2
2934
@test count_nodes(node2) == 2
3035

@@ -37,14 +42,13 @@ node2 = MyCustomNode(1.5, 3, node1, node1)
3742
@test count(t -> t.degree == 0, node2) == 2
3843

3944
# If we have a bad definition, it should get caught with a helpful message
40-
mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T}
45+
mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2}
4146
degree::UInt8
4247
constant::Bool
4348
val::T
4449
feature::UInt16
4550
op::UInt8
46-
l::MyCustomNode2{T}
47-
r::MyCustomNode2{T}
51+
children::NTuple{2,Base.RefValue{MyCustomNode2{T}}}
4852
end
4953

5054
@test_throws ErrorException MyCustomNode2()

test/test_equality.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2)
4545

4646
f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1))
4747
f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1))
48-
@test typeof(f64_tree) == GraphNode{Float64}
49-
@test typeof(f32_tree) == GraphNode{Float32}
48+
@test typeof(f64_tree) <: GraphNode{Float64}
49+
@test typeof(f32_tree) <: GraphNode{Float32}
5050

5151
@test convert(GraphNode{Float64}, f32_tree) == f64_tree
5252

test/test_extra_node_fields.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,31 @@
22

33
using Test
44
using DynamicExpressions
5-
using DynamicExpressions: constructorof
5+
using DynamicExpressions: constructorof, max_degree
66

7-
mutable struct FrozenNode{T} <: AbstractExpressionNode{T}
7+
mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D}
88
degree::UInt8
99
constant::Bool
1010
val::T
1111
frozen::Bool # Extra field!
1212
feature::UInt16
1313
op::UInt8
14-
l::FrozenNode{T}
15-
r::FrozenNode{T}
14+
children::NTuple{D,Base.RefValue{FrozenNode{T,D}}}
1615

17-
function FrozenNode{_T}() where {_T}
18-
n = new{_T}()
16+
function FrozenNode{_T,_D}() where {_T,_D}
17+
n = new{_T,_D}()
1918
n.frozen = false
2019
return n
2120
end
2221
end
22+
function DynamicExpressions.constructorof(::Type{N}) where {N<:FrozenNode}
23+
return FrozenNode{T,max_degree(N)} where {T}
24+
end
25+
function DynamicExpressions.with_type_parameters(
26+
::Type{N}, ::Type{T}
27+
) where {T,N<:FrozenNode}
28+
return FrozenNode{T,max_degree(N)}
29+
end
2330
function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T}
2431
out = if t.constant
2532
constructorof(typeof(t))(; val=t.val)
@@ -56,7 +63,7 @@ function DynamicExpressions.leaf_equal(a::FrozenNode, b::FrozenNode)
5663
end
5764
end
5865

59-
n = let n = FrozenNode{Float64}()
66+
n = let n = FrozenNode{Float64,2}()
6067
n.degree = 0
6168
n.constant = true
6269
n.val = 0.0
@@ -92,5 +99,5 @@ ex = parse_expression(
9299

93100
@test string_tree(ex) == "x + sin(y + 2.1)"
94101
@test ex.tree.frozen == false
95-
@test ex.tree.r.frozen == true
96-
@test ex.tree.r.l.frozen == false
102+
@test ex.tree.children[2][].frozen == true
103+
@test ex.tree.children[2][].children[1][].frozen == false

test/test_graphs.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,6 @@ end
109109
:(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())),
110110
)
111111
end
112-
113-
@testset "@with_memoize" begin
114-
ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize(
115-
_convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}()
116-
)
117-
true_ex = quote
118-
_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())
119-
end
120-
121-
@test expr_eql(ex, true_ex)
122-
end
123112
end
124113

125114
@testset "Operations on graphs" begin
@@ -283,7 +272,7 @@ end
283272
x = GraphNode(Float32; feature=1)
284273
tree = x + 1.0
285274
@test tree.l === x
286-
@test typeof(tree) === GraphNode{Float32}
275+
@test typeof(tree) <: GraphNode{Float32}
287276

288277
# Detect error from Float32(1im)
289278
@test_throws InexactError x + 1im

test/test_parse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ end
108108
variable_names = ["x"],
109109
)
110110

111-
@test typeof(ex.tree) === Node{Any}
111+
@test typeof(ex.tree) <: Node{Any}
112112
@test typeof(ex.metadata.operators) <: GenericOperatorEnum
113113
s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex)
114114
@test s == "[1, 2, 3] * tan(cos(5.0 + x))"
@@ -184,7 +184,7 @@ end
184184
s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex)
185185
@test s == "(x * 2.5) - cos(y)"
186186
end
187-
@test contains(logged_out, "Node{Float32}")
187+
@test contains(logged_out, "Node{Float32")
188188
end
189189

190190
@testitem "Helpful errors for missing operator" begin

0 commit comments

Comments
 (0)