|
1 | 1 | using SymbolicUtils, SymbolicUtils.Code, Test |
2 | | -using SymbolicUtils.Code: topological_sort |
| 2 | +using RuntimeGeneratedFunctions |
| 3 | + |
| 4 | +RuntimeGeneratedFunctions.init(@__MODULE__) |
3 | 5 |
|
4 | 6 | @testset "CSE" begin |
5 | 7 | @syms x |
6 | 8 | t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x)))) |
7 | 9 |
|
8 | 10 | @test t isa Let |
9 | | - @test length(t.pairs) == 4 |
10 | | - @test occursin(t.pairs[3].lhs, t.body) |
11 | | - @test occursin(t.pairs[4].lhs, t.body) |
| 11 | + @test length(t.pairs) == 5 |
| 12 | + @test occursin(t.pairs[3].lhs, t.pairs[5].rhs) |
| 13 | + @test occursin(t.pairs[4].lhs, t.pairs[5].rhs) |
12 | 14 | end |
13 | 15 |
|
14 | 16 | @testset "DAG CSE" begin |
|
26 | 28 | ab_node = sorted_nodes[1].lhs |
27 | 29 | @test isequal(term(^, ab_node, ab_node), sorted_nodes[2].rhs) |
28 | 30 | let_expr = cse(expr) |
29 | | - @test length(let_expr.pairs) == 1 |
| 31 | + @test length(let_expr.pairs) == 2 |
30 | 32 | @test isequal(let_expr.pairs[1].rhs, term(+, a, b)) |
31 | 33 | corresponding_sym = let_expr.pairs[1].lhs |
32 | | - @test isequal(let_expr.body, term(^, corresponding_sym, corresponding_sym)) |
| 34 | + @test isequal(let_expr.pairs[end].rhs, term(^, corresponding_sym, corresponding_sym)) |
33 | 35 |
|
34 | 36 | expr = a + b |
35 | 37 | sorted_nodes = topological_sort(expr) |
36 | 38 | @test length(sorted_nodes) == 1 |
37 | 39 | @test isequal(sorted_nodes[1].rhs, term(+, a, b)) |
38 | 40 | let_expr = cse(expr) |
39 | | - @test isempty(let_expr.pairs) |
40 | | - @test isequal(let_expr.body, term(+, a, b)) |
| 41 | + @test length(let_expr.pairs) == 1 |
| 42 | + @test isequal(let_expr.pairs[end].rhs, term(+, a, b)) |
41 | 43 |
|
42 | 44 | expr = a |
43 | 45 | sorted_nodes = topological_sort(expr) |
|
48 | 50 |
|
49 | 51 | # array symbolics |
50 | 52 | # https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739 |
51 | | - @syms c |
52 | | - function foo end |
53 | | - ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real) |
| 53 | + @syms a b c |
| 54 | + function foo(args...) |
| 55 | + return args |
| 56 | + end |
| 57 | + ex = term(foo, [a^2 + b^2, b^2 + c], (a^2 + b^2, b^2 + c), c; type = Real) |
54 | 58 | sorted_nodes = topological_sort(ex) |
55 | | - @test length(sorted_nodes) == 6 |
| 59 | + @test length(sorted_nodes) == 7 |
| 60 | + expr = quote |
| 61 | + a = 1 |
| 62 | + b = 2 |
| 63 | + c = 3 |
| 64 | + $(toexpr(cse(ex))) |
| 65 | + end |
| 66 | + vals = eval(expr) |
| 67 | + @test vals[1] == [1 + 4, 4 + 3] |
| 68 | + @test vals[2] == (1 + 4, 4 + 3) |
| 69 | + @test vals[3] == 3 |
| 70 | +end |
| 71 | + |
| 72 | +@testset "Expr" begin |
| 73 | + ex = :(a^2 + sin(a^2)) |
| 74 | + @test isequal(cse(ex).body, ex) |
| 75 | + ex = LiteralExpr(ex) |
| 76 | + @test isequal(cse(ex).body, ex) |
| 77 | +end |
| 78 | + |
| 79 | +@testset "Tuple" begin |
| 80 | + @syms a b |
| 81 | + ex = (a^2 + sin(a^2), sin(a^2) + b^2, b^2 + sin(b^2)) |
| 82 | + csex = cse(ex) |
| 83 | + i, j, k = findfirst.(isequal.(csex.body) .∘ Code.lhs, (csex.pairs,)) |
| 84 | + @test i !== nothing |
| 85 | + @test j !== nothing |
| 86 | + @test k !== nothing |
| 87 | + csex = Let(csex.pairs, MakeTuple(csex.body), false) |
| 88 | + expr = quote |
| 89 | + let a = 1, b = 2 |
| 90 | + $(toexpr(csex)) |
| 91 | + end |
| 92 | + end |
| 93 | + csex2 = cse(MakeTuple(collect(ex))) |
| 94 | + expr2 = quote |
| 95 | + let a = 1, b = 2 |
| 96 | + $(toexpr(csex2)) |
| 97 | + end |
| 98 | + end |
| 99 | + @test collect(eval(expr)) ≈ [1 + sin(1), sin(1) + 4, 4 + sin(4)] |
| 100 | + @test collect(eval(expr)) ≈ collect(eval(expr2)) |
| 101 | +end |
| 102 | + |
| 103 | +@testset "MakeArray, SetArray, MakeSparseArray, AtIndex" begin |
| 104 | + @syms a b c |
| 105 | + arr = [a^2 + sin(a * b) sin(a * b) + c^2 |
| 106 | + c^2 + sin(b * c) sin(b * c) + a^2] |
| 107 | + marr = MakeArray(arr, Array) |
| 108 | + sparr = sparse([1, 2, 3, 4], [1, 2, 3, 4], vec(arr)) |
| 109 | + msparr = MakeSparseArray(sparr) |
| 110 | + sarr = SetArray(false, :buffer, [[a^2 + c^2], AtIndex(3, arr), AtIndex(4, msparr)]) |
| 111 | + |
| 112 | + csex = cse(sarr) |
| 113 | + # test that simple array is CSEd |
| 114 | + @test findfirst(isequal(csex.body.elems[1][1]), Code.lhs.(csex.pairs)) !== nothing |
| 115 | + # test that `AtIndex` is CSEd |
| 116 | + i, j, k, l = findfirst.(isequal.(csex.body.elems[2].elem), (Code.lhs.(csex.pairs),)) |
| 117 | + @test i !== nothing |
| 118 | + @test j !== nothing |
| 119 | + @test k !== nothing |
| 120 | + @test l !== nothing |
| 121 | + # test that `MakeSpareArray` is CSEd, and re-uses the values from the `MakeArray` |
| 122 | + ii, jj, kk, ll = findfirst.(isequal.(findnz(csex.body.elems[3].elem.array)[3]), (Code.lhs.(csex.pairs),)) |
| 123 | + @test i == ii |
| 124 | + @test j == jj |
| 125 | + @test k == kk |
| 126 | + @test l == ll |
| 127 | + expr = quote |
| 128 | + let a = 1, b = 2, c = 3, buffer = Any[0, "A", 0, 0] |
| 129 | + $(toexpr(csex)) |
| 130 | + buffer |
| 131 | + end |
| 132 | + end |
| 133 | + val = eval(expr) |
| 134 | + @test val[1] == [10] |
| 135 | + @test val[2] == "A" |
| 136 | + result = [1 + sin(2) sin(2) + 9 |
| 137 | + 9 + sin(6) sin(6) + 1] |
| 138 | + @test val[3] == result |
| 139 | + @test val[4] == sparse([1, 2, 3, 4], [1, 2, 3, 4], vec(result)) |
| 140 | +end |
| 141 | + |
| 142 | +@testset "Let, Func, Assignment, DestructuredArgs" begin |
| 143 | + @syms a b c d::Array e f |
| 144 | + fn = Func([a, DestructuredArgs([b, c])], [], Let([Assignment(d, [a^2 + b^2, b^2 + c^2]), DestructuredArgs([e, f], term(broadcast, *, 2, d))], a^2 + b^2 + e + f)) |
| 145 | + csex = cse(fn) |
| 146 | + |
| 147 | + @test length(csex.body.pairs) == 9 |
| 148 | + sexprs = csex.body.pairs |
| 149 | + assignments = filter(x -> x isa Assignment, sexprs) |
| 150 | + @test sexprs[6].lhs === d |
| 151 | + # the array in the assignment should be CSEd |
| 152 | + i, j = findfirst.(isequal.(sexprs[6].rhs), (Code.lhs.(assignments),)) |
| 153 | + @test i !== nothing |
| 154 | + @test j !== nothing |
| 155 | + @test sexprs[8] isa DestructuredArgs |
| 156 | + @test isequal(sexprs[8].name, sexprs[7].lhs) |
| 157 | + |
| 158 | + rgf = @RuntimeGeneratedFunction(toexpr(csex)) |
| 159 | + trueval = let a = 1, |
| 160 | + b = 2, |
| 161 | + c = 3, |
| 162 | + tmp1 = b^2, |
| 163 | + tmp2 = a^2, |
| 164 | + tmp3 = tmp1 + tmp2, |
| 165 | + tmp4 = c^2, |
| 166 | + tmp5 = tmp1 + tmp4, |
| 167 | + d = [tmp3, tmp5], |
| 168 | + tmp6 = 2 .* d, |
| 169 | + e = tmp6[1], |
| 170 | + f = tmp6[2], |
| 171 | + tmp7 = f + tmp1 + tmp2 + e |
| 172 | + tmp7 |
| 173 | + end |
| 174 | + @test rgf(1, [2, 3]) == trueval |
| 175 | +end |
| 176 | + |
| 177 | +@testset "SpawnFetch" begin |
| 178 | + @syms a b c d |
| 179 | + fn = Func([c, d], [], c^2 + d^2 + sin(c^2)) |
| 180 | + ex = SpawnFetch{Multithreaded}([fn], [[a^2 + b^2, sin(a^2)]], only) |
| 181 | + csex = cse(ex) |
| 182 | + # arguments to the inner function are CSEd |
| 183 | + i, j = findfirst.(isequal.(csex.body.args[1]), (Code.lhs.(csex.pairs),)) |
| 184 | + @test i !== nothing |
| 185 | + @test j !== nothing |
| 186 | + innerkeys = Code.lhs.(csex.body.exprs[1].body.pairs) |
| 187 | + @test findfirst(isequal(csex.body.exprs[1].body.body), innerkeys) !== nothing |
| 188 | + |
| 189 | + expr = quote |
| 190 | + let a = 1, b = 2 |
| 191 | + $(toexpr(csex)) |
| 192 | + end |
| 193 | + end |
| 194 | + trueval = let a = 1, b = 2, c = a^2 + b^2, d = sin(a^2) |
| 195 | + c^2 + d^2 + sin(c^2) |
| 196 | + end |
| 197 | + @test eval(expr) == trueval |
| 198 | + innerfn = csex.body.exprs[1] |
| 199 | + |
| 200 | + # test inner function is CSEd independently |
| 201 | + rgf = @RuntimeGeneratedFunction(toexpr(innerfn)) |
| 202 | + @test rgf(5, sin(1)) == trueval |
56 | 203 | end |
0 commit comments