Skip to content

Commit 5220a81

Browse files
authored
Merge pull request #422 from JuliaSymbolics/s/simplify2
Fixes to simplify_fractions, call it in simplify by default
2 parents 62a8c27 + 144a0ae commit 5220a81

File tree

5 files changed

+72
-43
lines changed

5 files changed

+72
-43
lines changed

src/SymbolicUtils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ include("ordering.jl")
5454
include("simplify_rules.jl")
5555

5656
# API = simplify + substitute
57-
export simplify, substitute
58-
include("api.jl")
57+
export simplify
58+
include("simplify.jl")
59+
60+
export substitute
61+
include("substitute.jl")
5962

6063
# EGraph rewriting
6164
include("egraph.jl")

src/polyform.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ add_divs(x::Div, y) = (x.num + y * x.den) / x.den
274274
add_divs(x, y::Div) = (x * y.den + y.num) / y.den
275275
add_divs(x, y) = x + y
276276

277+
function frac_similarterm(x, f, args; kw...)
278+
if f in (*, /, \, +, -, ^)
279+
f(args...)
280+
else
281+
similarterm(x, f, args; kw...)
282+
end
283+
end
284+
277285
"""
278286
simplify_fractions(x; polyform=false)
279287
@@ -290,7 +298,9 @@ function simplify_fractions(x; polyform=false)
290298

291299
sdiv(a) = a isa Div ? simplify_div(a) : a
292300

293-
expr = Postwalk(sdiv quick_cancel)(Postwalk(add_with_div)(x))
301+
expr = Postwalk(sdiv quick_cancel,
302+
similarterm=frac_similarterm)(Postwalk(add_with_div,
303+
similarterm=frac_similarterm)(x))
294304

295305
polyform ? expr : unpolyize(expr)
296306
end

src/simplify.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
"""
3+
```julia
4+
simplify(x; expand=false,
5+
threaded=false,
6+
thread_subtree_cutoff=100,
7+
nonzero_denominators=true,
8+
rewriter=nothing)
9+
```
10+
11+
Simplify an expression (`x`) by applying `rewriter` until there are no changes.
12+
`expand=true` applies [`expand`](/api/#expand) in the beginning of each fixpoint iteration.
13+
14+
By default, simplify will assume denominators are not zero and allow cancellation in fractions.
15+
Pass `simplify_fractions=false` to prevent this.
16+
"""
17+
function simplify(x;
18+
expand=false,
19+
polynorm=nothing,
20+
threaded=false,
21+
simplify_fractions=true,
22+
thread_subtree_cutoff=100,
23+
rewriter=nothing)
24+
if polynorm !== nothing
25+
Base.depwarn("simplify(..; polynorm=$polynorm) is deprecated, use simplify(..; expand=$polynorm) instead",
26+
:simplify)
27+
end
28+
29+
30+
f = if rewriter === nothing
31+
if threaded
32+
threaded_simplifier(thread_subtree_cutoff)
33+
elseif expand
34+
serial_expand_simplifier
35+
else
36+
serial_simplifier
37+
end
38+
else
39+
Fixpoint(rewriter)
40+
end
41+
42+
x = PassThrough(f)(x)
43+
simplify_fractions && has_operation(x, /) ?
44+
SymbolicUtils.simplify_fractions(x) : x
45+
end
46+
47+
has_operation(x, op) = (istree(x) && (operation(x) == op ||
48+
any(a->has_operation(a, op),
49+
unsorted_arguments(x))))
50+
51+
Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)

src/api.jl renamed to src/substitute.jl

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,3 @@
1-
##### Numeric simplification
2-
3-
"""
4-
```julia
5-
simplify(x; expand=false,
6-
threaded=false,
7-
thread_subtree_cutoff=100,
8-
rewriter=nothing)
9-
```
10-
11-
Simplify an expression (`x`) by applying `rewriter` until there are no changes.
12-
`expand=true` applies [`expand`](/api/#expand) in the beginning of each fixpoint iteration.
13-
"""
14-
function simplify(x;
15-
expand=false,
16-
polynorm=nothing,
17-
threaded=false,
18-
thread_subtree_cutoff=100,
19-
rewriter=nothing)
20-
if polynorm !== nothing
21-
Base.depwarn("simplify(..; polynorm=$polynorm) is deprecated, use simplify(..; expand=$polynorm) instead",
22-
:simplify)
23-
end
24-
25-
f = if rewriter === nothing
26-
if threaded
27-
threaded_simplifier(thread_subtree_cutoff)
28-
elseif expand
29-
serial_expand_simplifier
30-
else
31-
serial_simplifier
32-
end
33-
else
34-
Fixpoint(rewriter)
35-
end
36-
37-
PassThrough(f)(x)
38-
end
39-
40-
Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)
411

422
"""
433
substitute(expr, dict; fold=true)

test/rulesets.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ end
106106
@eqtest simplify(exp(a) * a * exp(b)) == simplify(a*exp(a+b))
107107
end
108108

109+
@testset "simplify_fractions" begin
110+
@syms x y z
111+
@eqtest simplify(2*((y + z)/x) - 2*y/x - z/x*2) == 0
112+
end
113+
109114
@testset "Depth" begin
110115
@syms x
111116
R = Rewriters.Postwalk(Rewriters.Chain([@rule(sin(~x) => cos(~x)),

0 commit comments

Comments
 (0)