Skip to content

Commit 5860207

Browse files
authored
Merge pull request #383 from JuliaSymbolics/myb/unflatten
Unflatten * and + in toexpr
2 parents b0bf59a + 851aefd commit 5860207

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

src/code.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
77
SpawnFetch, Multithreaded
88

99
import ..SymbolicUtils
10-
import SymbolicUtils: @matchable, Sym, Term, istree, operation, arguments
10+
import SymbolicUtils: @matchable, Sym, Term, istree, operation, arguments,
11+
symtype
1112

1213
##== state management ==##
1314

@@ -98,6 +99,22 @@ toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))
9899

99100
function_to_expr(op, args, st) = nothing
100101

102+
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
103+
out = get(st.symbolify, O, nothing)
104+
out === nothing || return out
105+
args = map(Base.Fix2(toexpr, st), arguments(O))
106+
if length(args) >= 3 && symtype(O) <: Number
107+
x, xs = Iterators.peel(args)
108+
foldl(xs, init=x) do a, b
109+
Expr(:call, op, a, b)
110+
end
111+
else
112+
expr = Expr(:call, op)
113+
append!(expr.args, args)
114+
expr
115+
end
116+
end
117+
101118
function function_to_expr(::typeof(^), O, st)
102119
args = arguments(O)
103120
if length(args) == 2 && args[2] isa Real && args[2] < 0

test/code.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
1313
@test toexpr(Assignment(a, b)) == :(a = b)
1414
@test toexpr(a b) == :(a = b)
1515
@test toexpr(a+b) == :($(+)(a, b))
16+
@test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e))
17+
@test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e))
1618
@test toexpr(a+b) == :($(+)(a, b))
1719
@test toexpr(a^b) == :($(^)(a, b))
1820
@test toexpr(a^2) == :($(^)(a, 2))
1921
@test toexpr(a^-2) == :($(^)($(inv)(a), 2))
2022
@test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t)))
21-
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)(x(t), y(t), x($(+)(1, t))))
23+
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t))))
2224
s = LazyState()
2325
Code.union_symbolify!(s.symbolify, [x(t), y(t)])
24-
@test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)(var"x(t)", var"y(t)", x($(+)(1, t))))
26+
@test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)($(+)(var"x(t)", var"y(t)"), x($(+)(1, t))))
2527

2628
ex = :(let a = 3, b = $(+)(1,a)
2729
$(+)(a, b)
@@ -35,14 +37,14 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
3537

3638
test_repr(toexpr(Func([x(t), x],[b a+2, y(t) b], x(t)+x(t+1)+b+y(t))),
3739
:(function (var"x(t)", x; b = $(+)(2, a), var"y(t)" = b)
38-
$(+)(b, var"x(t)", var"y(t)", x($(+)(1, t)))
40+
$(+)($(+)($(+)(b, var"x(t)"), var"y(t)"), x($(+)(1, t)))
3941
end))
4042
test_repr(toexpr(Func([DestructuredArgs([x, x(t)], :state),
4143
DestructuredArgs((a, b), :params)], [],
4244
x(t+1) + x(t) + a + b)),
4345
:(function (state, params)
4446
let x = state[1], var"x(t)" = state[2], a = params[1], b = params[2]
45-
$(+)(a, b, var"x(t)", x($(+)(1, t)))
47+
$(+)($(+)($(+)(a, b), var"x(t)"), x($(+)(1, t)))
4648
end
4749
end))
4850

@@ -81,7 +83,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
8183
:(let foo = Any[3, 3, [1, 4]],
8284
var"x(t)" = foo[1], b = foo[2], c = foo[3],
8385
p = c[1], q = c[2]
84-
$(+)(a, b, c, var"x(t)")
86+
$(+)($(+)($(+)(a, b), c), var"x(t)")
8587
end))
8688

8789
test_repr(toexpr(Func([DestructuredArgs([a,b],c,inds=[:a, :b])], [],
@@ -119,7 +121,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
119121

120122
@test eval(toexpr(Let([a 1, b 2, arr @SLVector((:a, :b))(@SVector[1,2])],
121123
MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2])
122-
124+
123125
R1 = eval(toexpr(Let([a 1, b 2, arr @MVector([1,2])],MakeArray([a,b,a+b,a/b], arr))))
124126
@test R1 == (@MVector [1, 2, 3, 1/2]) && R1 isa MVector
125127

@@ -166,4 +168,3 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
166168
@test f(2) == 2
167169
end
168170
end
169-

0 commit comments

Comments
 (0)