Skip to content

Commit 9240c5c

Browse files
test: test new CSE
1 parent d20903f commit 9240c5c

File tree

1 file changed

+159
-12
lines changed

1 file changed

+159
-12
lines changed

test/cse.jl

Lines changed: 159 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
using SymbolicUtils, SymbolicUtils.Code, Test
2-
using SymbolicUtils.Code: topological_sort
2+
using RuntimeGeneratedFunctions
3+
4+
RuntimeGeneratedFunctions.init(@__MODULE__)
35

46
@testset "CSE" begin
57
@syms x
68
t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x))))
79

810
@test t isa Let
9-
@test length(t.pairs) == 4
10-
@test occursin(t.pairs[3].lhs, t.body)
11-
@test occursin(t.pairs[4].lhs, t.body)
11+
@test length(t.pairs) == 5
12+
@test occursin(t.pairs[3].lhs, t.pairs[5].rhs)
13+
@test occursin(t.pairs[4].lhs, t.pairs[5].rhs)
1214
end
1315

1416
@testset "DAG CSE" begin
@@ -26,18 +28,18 @@ end
2628
ab_node = sorted_nodes[1].lhs
2729
@test isequal(term(^, ab_node, ab_node), sorted_nodes[2].rhs)
2830
let_expr = cse(expr)
29-
@test length(let_expr.pairs) == 1
31+
@test length(let_expr.pairs) == 2
3032
@test isequal(let_expr.pairs[1].rhs, term(+, a, b))
3133
corresponding_sym = let_expr.pairs[1].lhs
32-
@test isequal(let_expr.body, term(^, corresponding_sym, corresponding_sym))
34+
@test isequal(let_expr.pairs[end].rhs, term(^, corresponding_sym, corresponding_sym))
3335

3436
expr = a + b
3537
sorted_nodes = topological_sort(expr)
3638
@test length(sorted_nodes) == 1
3739
@test isequal(sorted_nodes[1].rhs, term(+, a, b))
3840
let_expr = cse(expr)
39-
@test isempty(let_expr.pairs)
40-
@test isequal(let_expr.body, term(+, a, b))
41+
@test length(let_expr.pairs) == 1
42+
@test isequal(let_expr.pairs[end].rhs, term(+, a, b))
4143

4244
expr = a
4345
sorted_nodes = topological_sort(expr)
@@ -48,9 +50,154 @@ end
4850

4951
# array symbolics
5052
# https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739
51-
@syms c
52-
function foo end
53-
ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real)
53+
@syms a b c
54+
function foo(args...)
55+
return args
56+
end
57+
ex = term(foo, [a^2 + b^2, b^2 + c], (a^2 + b^2, b^2 + c), c; type = Real)
5458
sorted_nodes = topological_sort(ex)
55-
@test length(sorted_nodes) == 6
59+
@test length(sorted_nodes) == 7
60+
expr = quote
61+
a = 1
62+
b = 2
63+
c = 3
64+
$(toexpr(cse(ex)))
65+
end
66+
vals = eval(expr)
67+
@test vals[1] == [1 + 4, 4 + 3]
68+
@test vals[2] == (1 + 4, 4 + 3)
69+
@test vals[3] == 3
70+
end
71+
72+
@testset "Expr" begin
73+
ex = :(a^2 + sin(a^2))
74+
@test isequal(cse(ex).body, ex)
75+
ex = LiteralExpr(ex)
76+
@test isequal(cse(ex).body, ex)
77+
end
78+
79+
@testset "Tuple" begin
80+
@syms a b
81+
ex = (a^2 + sin(a^2), sin(a^2) + b^2, b^2 + sin(b^2))
82+
csex = cse(ex)
83+
i, j, k = findfirst.(isequal.(csex.body) .∘ Code.lhs, (csex.pairs,))
84+
@test i !== nothing
85+
@test j !== nothing
86+
@test k !== nothing
87+
csex = Let(csex.pairs, MakeTuple(csex.body), false)
88+
expr = quote
89+
let a = 1, b = 2
90+
$(toexpr(csex))
91+
end
92+
end
93+
csex2 = cse(MakeTuple(collect(ex)))
94+
expr2 = quote
95+
let a = 1, b = 2
96+
$(toexpr(csex2))
97+
end
98+
end
99+
@test collect(eval(expr)) [1 + sin(1), sin(1) + 4, 4 + sin(4)]
100+
@test collect(eval(expr)) collect(eval(expr2))
101+
end
102+
103+
@testset "MakeArray, SetArray, MakeSparseArray, AtIndex" begin
104+
@syms a b c
105+
arr = [a^2 + sin(a * b) sin(a * b) + c^2
106+
c^2 + sin(b * c) sin(b * c) + a^2]
107+
marr = MakeArray(arr, Array)
108+
sparr = sparse([1, 2, 3, 4], [1, 2, 3, 4], vec(arr))
109+
msparr = MakeSparseArray(sparr)
110+
sarr = SetArray(false, :buffer, [[a^2 + c^2], AtIndex(3, arr), AtIndex(4, msparr)])
111+
112+
csex = cse(sarr)
113+
# test that simple array is CSEd
114+
@test findfirst(isequal(csex.body.elems[1][1]), Code.lhs.(csex.pairs)) !== nothing
115+
# test that `AtIndex` is CSEd
116+
i, j, k, l = findfirst.(isequal.(csex.body.elems[2].elem), (Code.lhs.(csex.pairs),))
117+
@test i !== nothing
118+
@test j !== nothing
119+
@test k !== nothing
120+
@test l !== nothing
121+
# test that `MakeSpareArray` is CSEd, and re-uses the values from the `MakeArray`
122+
ii, jj, kk, ll = findfirst.(isequal.(findnz(csex.body.elems[3].elem.array)[3]), (Code.lhs.(csex.pairs),))
123+
@test i == ii
124+
@test j == jj
125+
@test k == kk
126+
@test l == ll
127+
expr = quote
128+
let a = 1, b = 2, c = 3, buffer = Any[0, "A", 0, 0]
129+
$(toexpr(csex))
130+
buffer
131+
end
132+
end
133+
val = eval(expr)
134+
@test val[1] == [10]
135+
@test val[2] == "A"
136+
result = [1 + sin(2) sin(2) + 9
137+
9 + sin(6) sin(6) + 1]
138+
@test val[3] == result
139+
@test val[4] == sparse([1, 2, 3, 4], [1, 2, 3, 4], vec(result))
140+
end
141+
142+
@testset "Let, Func, Assignment, DestructuredArgs" begin
143+
@syms a b c d::Array e f
144+
fn = Func([a, DestructuredArgs([b, c])], [], Let([Assignment(d, [a^2 + b^2, b^2 + c^2]), DestructuredArgs([e, f], term(broadcast, *, 2, d))], a^2 + b^2 + e + f))
145+
csex = cse(fn)
146+
147+
@test length(csex.body.pairs) == 9
148+
sexprs = csex.body.pairs
149+
assignments = filter(x -> x isa Assignment, sexprs)
150+
@test sexprs[6].lhs === d
151+
# the array in the assignment should be CSEd
152+
i, j = findfirst.(isequal.(sexprs[6].rhs), (Code.lhs.(assignments),))
153+
@test i !== nothing
154+
@test j !== nothing
155+
@test sexprs[8] isa DestructuredArgs
156+
@test isequal(sexprs[8].name, sexprs[7].lhs)
157+
158+
rgf = @RuntimeGeneratedFunction(toexpr(csex))
159+
trueval = let a = 1,
160+
b = 2,
161+
c = 3,
162+
tmp1 = b^2,
163+
tmp2 = a^2,
164+
tmp3 = tmp1 + tmp2,
165+
tmp4 = c^2,
166+
tmp5 = tmp1 + tmp4,
167+
d = [tmp3, tmp5],
168+
tmp6 = 2 .* d,
169+
e = tmp6[1],
170+
f = tmp6[2],
171+
tmp7 = f + tmp1 + tmp2 + e
172+
tmp7
173+
end
174+
@test rgf(1, [2, 3]) == trueval
175+
end
176+
177+
@testset "SpawnFetch" begin
178+
@syms a b c d
179+
fn = Func([c, d], [], c^2 + d^2 + sin(c^2))
180+
ex = SpawnFetch{Multithreaded}([fn], [[a^2 + b^2, sin(a^2)]], only)
181+
csex = cse(ex)
182+
# arguments to the inner function are CSEd
183+
i, j = findfirst.(isequal.(csex.body.args[1]), (Code.lhs.(csex.pairs),))
184+
@test i !== nothing
185+
@test j !== nothing
186+
innerkeys = Code.lhs.(csex.body.exprs[1].body.pairs)
187+
@test findfirst(isequal(csex.body.exprs[1].body.body), innerkeys) !== nothing
188+
189+
expr = quote
190+
let a = 1, b = 2
191+
$(toexpr(csex))
192+
end
193+
end
194+
trueval = let a = 1, b = 2, c = a^2 + b^2, d = sin(a^2)
195+
c^2 + d^2 + sin(c^2)
196+
end
197+
@test eval(expr) == trueval
198+
innerfn = csex.body.exprs[1]
199+
200+
# test inner function is CSEd independently
201+
rgf = @RuntimeGeneratedFunction(toexpr(innerfn))
202+
@test rgf(5, sin(1)) == trueval
56203
end

0 commit comments

Comments
 (0)