Skip to content

Commit b0bf59a

Browse files
authored
Merge pull request #382 from kllrak/kllrak/exprules
Add rules to simplify exponentials.
2 parents 51b335b + a321227 commit b0bf59a

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

src/simplify_rules.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ let
5656
@rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z)
5757
]
5858

59-
TRIG_RULES = [
59+
TRIG_EXP_RULES = [
6060
@acrule(sin(~x)^2 + cos(~x)^2 => one(~x))
6161
@acrule(sin(~x)^2 + -1 => cos(~x)^2)
6262
@acrule(cos(~x)^2 + -1 => sin(~x)^2)
@@ -68,6 +68,9 @@ let
6868
@acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x))
6969
@acrule(cot(~x)^2 + 1 => csc(~x)^2)
7070
@acrule(csc(~x)^2 + -1 => cot(~x)^2)
71+
72+
@acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y))
73+
@rule(exp(~x)^(~y) => exp(~x * ~y))
7174
]
7275

7376
BOOLEAN_RULES = [
@@ -112,7 +115,7 @@ let
112115
rule_tree
113116
end
114117

115-
trig_simplifier(;kw...) = Chain(TRIG_RULES)
118+
trig_exp_simplifier(;kw...) = Chain(TRIG_EXP_RULES)
116119

117120
bool_simplifier() = Chain(BOOLEAN_RULES)
118121

@@ -123,10 +126,10 @@ let
123126
global serial_expand_simplifier
124127

125128
function default_simplifier(; kw...)
126-
IfElse(has_trig,
129+
IfElse(has_trig_exp,
127130
Postwalk(IfElse(x->symtype(x) <: Number,
128131
Chain((number_simplifier(),
129-
trig_simplifier())),
132+
trig_exp_simplifier())),
130133
If(x->symtype(x) <: Bool,
131134
bool_simplifier()))
132135
; kw...),

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ pow(x, y::Symbolic) = Base.:^(x,y)
7777
pow(x::Symbolic,y::Symbolic) = Base.:^(x,y)
7878

7979
# Simplification utilities
80-
function has_trig(term)
80+
function has_trig_exp(term)
8181
!istree(term) && return false
82-
fns = (sin, cos, tan, cot, sec, csc)
82+
fns = (sin, cos, tan, cot, sec, csc, exp)
8383
op = operation(term)
8484

85-
if Base.@nany 6 i->fns[i] === op
85+
if Base.@nany 7 i->fns[i] === op
8686
return true
8787
else
88-
return any(has_trig, arguments(term))
88+
return any(has_trig_exp, arguments(term))
8989
end
9090
end
9191

test/rulesets.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ end
8888
@eqtest simplify(1 + y + cot(x)^2) == csc(x)^2 + y
8989
end
9090

91+
@testset "Exponentials" begin
92+
@syms a::Real b::Real
93+
@eqtest simplify(exp(a)*exp(b)) == simplify(exp(a+b))
94+
@eqtest simplify(exp(a)*exp(a)) == simplify(exp(2a))
95+
@test simplify(exp(a)*exp(-a)) == 1
96+
@eqtest simplify(exp(a)^2) == simplify(exp(2a))
97+
@eqtest simplify(exp(a) * a * exp(b)) == simplify(a*exp(a+b))
98+
end
99+
91100
@testset "Depth" begin
92101
@syms x
93102
R = Rewriters.Postwalk(Rewriters.Chain([@rule(sin(~x) => cos(~x)),

0 commit comments

Comments
 (0)