Skip to content

Commit 369eda5

Browse files
authored
Merge pull request #353 from JuliaSymbolics/s/frac-perf
speed up `flatten_fraction`
2 parents cf2861d + 8aaa7e0 commit 369eda5

File tree

4 files changed

+42
-20
lines changed

4 files changed

+42
-20
lines changed

benchmark/benchmarks.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,22 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
6666
overhead["substitute"]["a,b,c"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2, c=>3))) setup=begin
6767
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
6868
end
69+
70+
71+
end
72+
73+
let
74+
pform = SUITE["polyform"] = BenchmarkGroup()
75+
@syms a b c d e f g h i
76+
ex = (f + ((((g*(c^2)*(e^2)) / d - e*h*(c^2)) / b + (-c*e*f*g) / d + c*e*i) /
77+
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d) - c*e) / b +
78+
((g*(f^2)) / d + ((-c*e*f*g) / d + c*f*h) / b - f*i) /
79+
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d)) / d
80+
81+
o = (d + (e*((c*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d))) / b +
82+
(-f*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d)) / d
83+
pform["simplify_fractions"] = @benchmarkable simplify_fractions($ex)
84+
pform["iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($ex)
85+
pform["isone"] = @benchmarkable SymbolicUtils.fraction_isone($o)
86+
pform["easy_iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($((b*(h + (-e*g) / d)) / b + (e*g) / d - h))
6987
end

src/polyform.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,10 @@ function simplify_div(d::Div)
257257
end
258258
end
259259

260-
function add_divs(x, y)
261-
x_num, x_den = polyform_factors(x, get_pvar2sym(), get_sym2term())
262-
y_num, y_den = polyform_factors(y, get_pvar2sym(), get_sym2term())
263-
264-
(_mul(x_num, y_den) + _mul(x_den, y_num)) / (_mul(x_den, y_den))
265-
end
260+
add_divs(x::Div, y::Div) = (x.num * y.den + y.num * x.den) / (x.den * y.den)
261+
add_divs(x::Div, y) = (x.num + y * x.den) / x.den
262+
add_divs(x, y::Div) = (x * y.den + y.num) / y.den
263+
add_divs(x, y) = x + y
266264

267265
"""
268266
simplify_fractions(x)
@@ -275,14 +273,22 @@ function simplify_fractions(x)
275273

276274
!needs_div_rules(x) && return x
277275

278-
isdiv(x) = x isa Div
279-
280-
rules = [@rule ~x::isdiv => simplify_div(~x)
281-
@acrule ~a::isdiv + ~b::isdiv => add_divs(~a,~b)]
276+
sdiv(a) = a isa Div ? simplify_div(a) : a
282277

283-
Fixpoint(Postwalk(RestartedChain(rules)))(x)
278+
Postwalk(sdiv quick_cancel)(Postwalk(add_with_div)(x))
284279
end
285280

281+
function add_with_div(x, flatten=true)
282+
(!istree(x) || operation(x) != (+)) && return x
283+
aa = unsorted_arguments(x)
284+
!any(a->a isa Div, aa) && return x # no rewrite necessary
285+
286+
divs = filter(a->a isa Div, aa)
287+
nondivs = filter(a->!(a isa Div), aa)
288+
nds = isempty(nondivs) ? 0 : +(nondivs...)
289+
d = reduce(quick_canceladd_divs, divs)
290+
flatten ? quick_cancel(add_divs(d, nds)) : d + nds
291+
end
286292
"""
287293
flatten_fractions(x)
288294
@@ -294,17 +300,15 @@ julia> flatten_fractions((1+(1+1/a)/a)/a)
294300
```
295301
"""
296302
function flatten_fractions(x)
297-
rules = [@acrule ~a::(x->x isa Div) + ~b => add_divs(~a,~b)
298-
@rule *(~~x, ~a / ~b, ~~y) / ~c => *((~~x)..., ~a, (~~y)...) / (~b * ~c)
299-
@rule ~c / *(~~x, ~a / ~b, ~~y) => (~b * ~c) / *((~~x)..., ~a, (~~y)...)]
300-
Fixpoint(Postwalk(Chain(rules)))(x)
303+
Fixpoint(Postwalk(add_with_div))(x)
301304
end
302305

303306
function fraction_iszero(x)
304307
!istree(x) && return _iszero(x)
308+
ff = flatten_fractions(x)
305309
# fast path and then slow path
306-
any(_iszero, numerators(flatten_fractions(x))) ||
307-
any(_iszeroexpand, numerators(flatten_fractions(x)))
310+
any(_iszero, numerators(ff)) ||
311+
any(_iszeroexpand, numerators(ff))
308312
end
309313

310314
function fraction_isone(x)

src/rewriters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ rewriters.
2929
3030
"""
3131
module Rewriters
32-
using SymbolicUtils: @timer
32+
using SymbolicUtils: @timer, unsorted_arguments
3333
using TermInterface: is_operation, istree, operation, similarterm, arguments, node_count
3434

3535
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
@@ -159,7 +159,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
159159
x = p.rw(x)
160160
end
161161
if istree(x)
162-
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
162+
x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)))
163163
end
164164
return ord === :post ? p.rw(x) : x
165165
else

src/rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ function (acr::ACRule)(term)
356356
end
357357

358358
T = symtype(term)
359-
args = arguments(term)
359+
args = unsorted_arguments(term)
360360

361361
itr = acr.sets(eachindex(args), acr.arity)
362362

0 commit comments

Comments
 (0)