Skip to content

Commit 822d1b4

Browse files
authored
Merge pull request #100 from JuliaSymbolics/s/fold
propagate constants after substitute
2 parents c5bc4bd + 2a77262 commit 822d1b4

File tree

5 files changed

+20
-6
lines changed

5 files changed

+20
-6
lines changed

src/rule_dsl.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,12 @@ macro rule(expr)
162162
lhs_term = makepattern(lhs, keys)
163163
unique!(keys)
164164
quote
165+
$(__source__)
165166
lhs_pattern = $(lhs_term)
166167
Rule($(QuoteNode(expr)),
167168
lhs_pattern,
168169
matcher(lhs_pattern),
169-
(__MATCHES__, __CTX__) -> $(makeconsequent(rhs)),
170+
(__MATCHES__, __CTX__) -> ($(__source__); $(makeconsequent(rhs))),
170171
rule_depth($lhs_term))
171172
end
172173
end

src/rulesets.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ const TIMES_RULES = RuleSet([
5050
])
5151

5252
const POW_RULES = RuleSet([
53-
@rule(^(*(~~x), ~y) => *(map(a->pow(a, ~y), ~~x)...))
54-
@rule((((~x)^(~p))^(~q)) => (~x)^((~p)*(~q)))
53+
@rule(^(*(~~x), ~y::isliteral(Integer)) => *(map(a->pow(a, ~y), ~~x)...))
54+
@rule((((~x)^(~p::isliteral(Integer)))^(~q::isliteral(Integer))) => (~x)^((~p)*(~q)))
5555
@rule(^(~x, ~z::_iszero) => 1)
5656
@rule(^(~x, ~z::_isone) => ~x)
5757
])

src/simplify.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,18 @@ substitute any subexpression that matches a key in `dict` with
5757
the corresponding value.
5858
"""
5959
function substitute(expr, dict)
60-
RuleSet([@rule ~x::(x->haskey(dict, x)) => dict[~x]])(expr)
60+
RuleSet([@rule ~x::(x->haskey(dict, x)) => dict[~x]])(expr) |> fold
61+
end
62+
63+
fold(x) = x
64+
function fold(t::Term)
65+
tt = map(fold, arguments(t))
66+
if !any(x->x isa Symbolic, tt)
67+
# evaluate it
68+
return operation(t)(tt...)
69+
else
70+
return Term{symtype(t)}(operation(t), tt)
71+
end
6172
end
6273

6374
### Predicates

test/basics.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ end
101101
@syms a b
102102
@test substitute(a, Dict(a=>1)) == 1
103103
@test isequal(substitute(sin(a+b), Dict(a=>1)), sin(1+b))
104-
@test substitute(a+b, Dict(a=>1, b=>3)) |> simplify == 4
104+
@test substitute(a+b, Dict(a=>1, b=>3)) == 4
105+
@test substitute(exp(a), Dict(a=>2)) exp(2)
105106
end
106107

107108
@testset "printing" begin

test/rulesets.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ end
2525
@eqtest simplify(1 * x * 2) == 2 * x
2626
@eqtest simplify(1 + x + 2) == 3 + x
2727
@eqtest simplify(b*b) == b^2 # tests merge_repeats
28-
@eqtest simplify((a*b)^c) == a^c * b^c
28+
@eqtest simplify((a*b)^2) == a^2 * b^2
29+
@eqtest simplify((a*b)^c) == (a*b)^c
2930

3031
@eqtest simplify(1x + 2x) == 3x
3132
@eqtest simplify(3x + 2x) == 5x

0 commit comments

Comments
 (0)