Skip to content

Commit c4e2ec2

Browse files
authored
Merge branch 'master' into s/cse
2 parents 9d0176d + 686ea2e commit c4e2ec2

File tree

8 files changed

+50
-18
lines changed

8 files changed

+50
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "0.16.0"
4+
version = "0.17.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ include("rewriters.jl")
3535
using .Rewriters
3636

3737
using Combinatorics: permutations, combinations
38-
export @rule, @acrule, RuleSet, @capture
38+
export @rule, @acrule, RuleSet
3939

4040
# Rule type and @rule macro
4141
include("rule.jl")

src/code.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))
100100

101101
function_to_expr(op, args, st) = nothing
102102

103+
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
104+
out = get(st.symbolify, O, nothing)
105+
out === nothing || return out
106+
args = map(Base.Fix2(toexpr, st), arguments(O))
107+
if length(args) >= 3 && symtype(O) <: Number
108+
x, xs = Iterators.peel(args)
109+
foldl(xs, init=x) do a, b
110+
Expr(:call, op, a, b)
111+
end
112+
else
113+
expr = Expr(:call, op)
114+
append!(expr.args, args)
115+
expr
116+
end
117+
end
118+
103119
function function_to_expr(::typeof(^), O, st)
104120
args = arguments(O)
105121
if length(args) == 2 && args[2] isa Real && args[2] < 0

src/simplify_rules.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ let
5656
@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z)
5757
]
5858

59-
TRIG_RULES = [
59+
TRIG_EXP_RULES = [
6060
@acrule(sin(~x)^2 + cos(~x)^2 => one(~x))
6161
@acrule(sin(~x)^2 + -1 => cos(~x)^2)
6262
@acrule(cos(~x)^2 + -1 => sin(~x)^2)
@@ -68,6 +68,9 @@ let
6868
@acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x))
6969
@acrule(cot(~x)^2 + 1 => csc(~x)^2)
7070
@acrule(csc(~x)^2 + -1 => cot(~x)^2)
71+
72+
@acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y))
73+
@rule(exp(~x)^(~y) => exp(~x * ~y))
7174
]
7275

7376
BOOLEAN_RULES = [
@@ -112,7 +115,7 @@ let
112115
rule_tree
113116
end
114117

115-
trig_simplifier(;kw...) = Chain(TRIG_RULES)
118+
trig_exp_simplifier(;kw...) = Chain(TRIG_EXP_RULES)
116119

117120
bool_simplifier() = Chain(BOOLEAN_RULES)
118121

@@ -123,10 +126,10 @@ let
123126
global serial_expand_simplifier
124127

125128
function default_simplifier(; kw...)
126-
IfElse(has_trig,
129+
IfElse(has_trig_exp,
127130
Postwalk(IfElse(x->symtype(x) <: Number,
128131
Chain((number_simplifier(),
129-
trig_simplifier())),
132+
trig_exp_simplifier())),
130133
If(x->symtype(x) <: Bool,
131134
bool_simplifier()))
132135
; kw...),

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ pow(x, y::Symbolic) = Base.:^(x,y)
7777
pow(x::Symbolic,y::Symbolic) = Base.:^(x,y)
7878

7979
# Simplification utilities
80-
function has_trig(term)
80+
function has_trig_exp(term)
8181
!istree(term) && return false
82-
fns = (sin, cos, tan, cot, sec, csc)
82+
fns = (sin, cos, tan, cot, sec, csc, exp)
8383
op = operation(term)
8484

85-
if Base.@nany 6 i->fns[i] === op
85+
if Base.@nany 7 i->fns[i] === op
8686
return true
8787
else
88-
return any(has_trig, arguments(term))
88+
return any(has_trig_exp, arguments(term))
8989
end
9090
end
9191

test/code.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
1313
@test toexpr(Assignment(a, b)) == :(a = b)
1414
@test toexpr(a b) == :(a = b)
1515
@test toexpr(a+b) == :($(+)(a, b))
16+
@test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e))
17+
@test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e))
1618
@test toexpr(a+b) == :($(+)(a, b))
1719
@test toexpr(a^b) == :($(^)(a, b))
1820
@test toexpr(a^2) == :($(^)(a, 2))
1921
@test toexpr(a^-2) == :($(^)($(inv)(a), 2))
2022
@test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t)))
21-
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)(x(t), y(t), x($(+)(1, t))))
23+
@test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t))))
2224
s = LazyState()
2325
Code.union_symbolify!(s.symbolify, [x(t), y(t)])
24-
@test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)(var"x(t)", var"y(t)", x($(+)(1, t))))
26+
@test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)($(+)(var"x(t)", var"y(t)"), x($(+)(1, t))))
2527

2628
ex = :(let a = 3, b = $(+)(1,a)
2729
$(+)(a, b)
@@ -35,14 +37,14 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
3537

3638
test_repr(toexpr(Func([x(t), x],[b a+2, y(t) b], x(t)+x(t+1)+b+y(t))),
3739
:(function (var"x(t)", x; b = $(+)(2, a), var"y(t)" = b)
38-
$(+)(b, var"x(t)", var"y(t)", x($(+)(1, t)))
40+
$(+)($(+)($(+)(b, var"x(t)"), var"y(t)"), x($(+)(1, t)))
3941
end))
4042
test_repr(toexpr(Func([DestructuredArgs([x, x(t)], :state),
4143
DestructuredArgs((a, b), :params)], [],
4244
x(t+1) + x(t) + a + b)),
4345
:(function (state, params)
4446
let x = state[1], var"x(t)" = state[2], a = params[1], b = params[2]
45-
$(+)(a, b, var"x(t)", x($(+)(1, t)))
47+
$(+)($(+)($(+)(a, b), var"x(t)"), x($(+)(1, t)))
4648
end
4749
end))
4850

@@ -81,7 +83,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
8183
:(let foo = Any[3, 3, [1, 4]],
8284
var"x(t)" = foo[1], b = foo[2], c = foo[3],
8385
p = c[1], q = c[2]
84-
$(+)(a, b, c, var"x(t)")
86+
$(+)($(+)($(+)(a, b), c), var"x(t)")
8587
end))
8688

8789
test_repr(toexpr(Func([DestructuredArgs([a,b],c,inds=[:a, :b])], [],
@@ -119,7 +121,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
119121

120122
@test eval(toexpr(Let([a 1, b 2, arr @SLVector((:a, :b))(@SVector[1,2])],
121123
MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2])
122-
124+
123125
R1 = eval(toexpr(Let([a 1, b 2, arr @MVector([1,2])],MakeArray([a,b,a+b,a/b], arr))))
124126
@test R1 == (@MVector [1, 2, 3, 1/2]) && R1 isa MVector
125127

@@ -166,4 +168,3 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
166168
@test f(2) == 2
167169
end
168170
end
169-

test/rewrite.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ end
4343
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])
4444
end
4545

46+
using SymbolicUtils: @capture
47+
4648
@testset "Capture form" begin
49+
4750
ex = a^a
4851

4952
#note that @test inserts a soft local scope (try-catch) that would gobble
@@ -72,4 +75,4 @@ end
7275

7376
@eqtest f(b^b) == b
7477
@test f(b+b) == nothing
75-
end
78+
end

test/rulesets.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ end
8888
@eqtest simplify(1 + y + cot(x)^2) == csc(x)^2 + y
8989
end
9090

91+
@testset "Exponentials" begin
92+
@syms a::Real b::Real
93+
@eqtest simplify(exp(a)*exp(b)) == simplify(exp(a+b))
94+
@eqtest simplify(exp(a)*exp(a)) == simplify(exp(2a))
95+
@test simplify(exp(a)*exp(-a)) == 1
96+
@eqtest simplify(exp(a)^2) == simplify(exp(2a))
97+
@eqtest simplify(exp(a) * a * exp(b)) == simplify(a*exp(a+b))
98+
end
99+
91100
@testset "Depth" begin
92101
@syms x
93102
R = Rewriters.Postwalk(Rewriters.Chain([@rule(sin(~x) => cos(~x)),

0 commit comments

Comments
 (0)