Skip to content

Commit a247c69

Browse files
committed
Add tests for tree map related functions
1 parent f21c4cf commit a247c69

File tree

3 files changed

+169
-1
lines changed

3 files changed

+169
-1
lines changed

src/base.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import Base:
1515
in,
1616
isempty,
1717
iterate,
18-
keys,
1918
length,
2019
map,
2120
map!,

test/test_base.jl

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
using DynamicExpressions
2+
using Random
3+
using Test
4+
5+
operators = OperatorEnum(;
6+
binary_operators=[+, -, *, /], unary_operators=[cos, sin], enable_autodiff=true
7+
);
8+
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
9+
tree = cos(x1 * 3.2 - 5.8) * 0.2 - 0.5 * x2 * x3 * x3 + 0.9 / (x1 * x1 + 1);
10+
11+
@testset "all" begin
12+
ctree = copy(tree)
13+
@test all(t -> t.degree != -1, ctree)
14+
@test all(
15+
t -> t.degree != 0 || !t.constant || t.val in (3.2, 5.8, 0.2, 0.5, 0.9, 1.0), ctree
16+
)
17+
@test !all(
18+
t -> t.degree != 0 || !t.constant || t.val in (3.2, 5.8, 0.2, 0.9, 1.0), ctree
19+
)
20+
@test all(t -> t.degree != 0 || t.constant || t.feature in (1, 2, 3), ctree)
21+
@test !all(t -> t.degree != 0 || t.constant || t.feature in (2, 3), ctree)
22+
@test all(t -> t.degree != 1 || t.op == 1, ctree)
23+
@test !all(t -> t.degree != 1 || t.op == 2, ctree)
24+
end
25+
26+
@testset "any" begin
27+
ctree = copy(tree)
28+
@test any(t.degree == 2, ctree)
29+
@test any(_ -> true, ctree)
30+
@test any(t -> t.degree == 0 && t.constant && t.val == 3.2, ctree)
31+
@test !any(t -> t.degree == 0 && t.constant && t.val == 3.3, ctree)
32+
end
33+
34+
@testset "collect" begin
35+
ctree = copy(tree)
36+
@test first(collect(ctree)) == Node{Float64}
37+
@test objectid(first(collect(ctree))) == objectid(ctree)
38+
@test objectid(first(collect(ctree))) == objectid(ctree)
39+
@test objectid(first(collect(ctree))) == objectid(ctree)
40+
@test typeof(collect(ctree)) == Vector{Node{Float64}}
41+
@test length(collect(ctree)) == 24
42+
@test sum((t -> (t.degree == 0 && t.constant) ? t.val : 0.0).(collect(ctree))) == 11.6
43+
end
44+
45+
@testset "count" begin
46+
ctree = copy(tree)
47+
@test count(_ -> true, ctree) == 24
48+
@test count(t -> t.degree == 0, ctree) == 12
49+
@test count(t -> t.degree == 1, ctree) == 1
50+
@test count(t -> t.degree == 2, ctree) == 11
51+
@test count(t -> t.degree == 0 && t.constant, ctree) == 6
52+
@test count(t -> t.degree == 0 && t.constant && t.val == 1, ctree) == 1
53+
end
54+
55+
@testset "filter" begin
56+
ctree = copy(tree)
57+
@test filter(_ -> true, ctree) == collect(ctree)
58+
@test length(filter(t -> t.degree == 0 && !t.constant)) == 6
59+
@test unique(filter(t -> t.degree == 0 && !t.constant)) == [x1, x2, x3]
60+
@test length(filter(t -> t.degree == 1, ctree)) == 1
61+
@test length(filter(t -> t.degree == 2, ctree)) == 11
62+
@test filter(==(x1), ctree) == [x1, x1, x1]
63+
end
64+
65+
@testset "foreach" begin
66+
ctree = copy(tree)
67+
counter = Ref(0)
68+
foreach(ctree) do t
69+
counter.x += 1
70+
end
71+
@test counter.x == 24
72+
foreach(ctree) do t
73+
if t.degree == 0 && t.constant
74+
t.val *= 2
75+
end
76+
end
77+
@test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2
78+
end
79+
80+
@testset "iterate" begin
81+
ctree = copy(tree)
82+
counter = Ref(0)
83+
for t in ctree
84+
counter.x += 1
85+
end
86+
@test counter.x == 24
87+
for t in ctree
88+
if t.degree == 0 && t.constant
89+
t.val *= 2
90+
end
91+
end
92+
@test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2
93+
94+
# iterate within iterate:
95+
counter = Ref(0)
96+
for t in ctree
97+
for t2 in t
98+
counter.x += 1
99+
end
100+
end
101+
@test counter.x == 104
102+
end
103+
104+
@testset "map" begin
105+
ctree = copy(tree)
106+
vals = map(t -> t.val, ctree)
107+
vals = [v for v in vals if v !== nothing]
108+
@test sum(vals) == 11.6
109+
@test sum(map(_ -> 1, ctree)) == 24
110+
@test sum(map(_ -> 2, ctree)) == 24 * 2
111+
@test sum(map(t -> t.degree == 1, ctree)) == 1
112+
@test length(unique(map(objectid, copy_node(tree)))) == 24
113+
@test length(unique(map(objectid, copy_node(tree; preserve_sharing=true)))) == 24 - 3
114+
map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree)
115+
@test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2
116+
end
117+
118+
@testset "in" begin
119+
ctree = copy(tree)
120+
@test x1 in ctree
121+
@test Node(Float64; val=1.0) ctree
122+
@test Node(Float32; val=1.0) ctree
123+
@test Node(Float64; val=1.1) ctree
124+
@test ctree.l ctree
125+
@test ctree.l * 2 ctree
126+
end
127+
128+
@testset "isempty" begin
129+
@test !isempty(tree)
130+
@test !isempty(Node(Float32; val=1))
131+
end
132+
133+
@testset "length" begin
134+
@test length(tree) == 24
135+
@test length(tree.l) == 16
136+
@test length(tree.r) == 24 - 16 - 1
137+
end
138+
139+
@testset "mapreduce" begin
140+
@test mapreduce(_ -> 1, +, tree) == 24
141+
@test mapreduce(_ -> 2, +, tree) == 48
142+
@test mapreduce(_ -> 1, *, tree) == 1
143+
@test mapreduce(_ -> 2, *, tree) == 2^24
144+
@test mapreduce(t -> t.degree, *, tree) == 0
145+
@test mapreduce(t -> t.degree + 1, *, tree) == 354294
146+
@test mapreduce(t -> t.degree, (l, r) -> (max(l, 1) * max(r, 1)), tree) == 2048
147+
@test mapreduce(+, tree) do t
148+
1
149+
end == 24
150+
end
151+
152+
@testset "sum" begin
153+
ctree = copy(tree)
154+
@test sum(t -> t.degree == 0 && t.constant ? t.val : 0.0, ctree) == 11.6
155+
@test sum(t -> t.degree == 0 && !t.constant ? t.feature : 0, ctree) ==
156+
3 * 1 + 1 * 2 + 2 * 3
157+
@test sum(t -> t.degree == 1 ? t.op : 0, ctree) == 1
158+
@test sum(t -> (t.degree == 0 && t.constant) ? t.val * 2 : 0.0, ctree) == 11.6 * 2
159+
for t in ctree
160+
if t.degree == 0 && t.constant
161+
t.val *= 1.5
162+
end
163+
end
164+
@test sum(t -> (t.degree == 0 && t.constant) ? t.val : 0.0, ctree) 11.6 * 1.5
165+
end

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,7 @@ end
9393
@safetestset "Test precompilation" begin
9494
include("test_precompilation.jl")
9595
end
96+
97+
@safetestset "Test Base" begin
98+
include("test_base.jl")
99+
end

0 commit comments

Comments
 (0)