Skip to content

Commit 102253f

Browse files
committed
speed up flatten_fraction
1 parent cf2861d commit 102253f

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

src/polyform.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,9 @@ function simplify_div(d::Div)
257257
end
258258
end
259259

260-
function add_divs(x, y)
261-
x_num, x_den = polyform_factors(x, get_pvar2sym(), get_sym2term())
262-
y_num, y_den = polyform_factors(y, get_pvar2sym(), get_sym2term())
263-
264-
(_mul(x_num, y_den) + _mul(x_den, y_num)) / (_mul(x_den, y_den))
265-
end
260+
add_divs(x::Div, y::Div) = (x.num * y.den + y.num * x.den) / (x.den * y.den)
261+
add_divs(x::Div, y) = (x.num + y * x.den) / x.den
262+
add_divs(x, y::Div) = (x * y.den + y.num) / y.den
266263

267264
"""
268265
simplify_fractions(x)
@@ -283,6 +280,17 @@ function simplify_fractions(x)
283280
Fixpoint(Postwalk(RestartedChain(rules)))(x)
284281
end
285282

283+
function add_with_div(x)
284+
(!istree(x) || operation(x) != (+)) && return nothing
285+
aa = unsorted_arguments(x)
286+
!any(a->a isa Div, aa) && return nothing # no rewrite necessary
287+
288+
divs = filter(a->a isa Div, aa)
289+
nondivs = filter(a->!(a isa Div), aa)
290+
nds = isempty(nondivs) ? 0 : +(nondivs...)
291+
292+
return quick_cancel(add_divs(reduce(quick_canceladd_divs, divs), nds))
293+
end
286294
"""
287295
flatten_fractions(x)
288296
@@ -294,17 +302,15 @@ julia> flatten_fractions((1+(1+1/a)/a)/a)
294302
```
295303
"""
296304
function flatten_fractions(x)
297-
rules = [@acrule ~a::(x->x isa Div) + ~b => add_divs(~a,~b)
298-
@rule *(~~x, ~a / ~b, ~~y) / ~c => *((~~x)..., ~a, (~~y)...) / (~b * ~c)
299-
@rule ~c / *(~~x, ~a / ~b, ~~y) => (~b * ~c) / *((~~x)..., ~a, (~~y)...)]
300-
Fixpoint(Postwalk(Chain(rules)))(x)
305+
Fixpoint(Postwalk(PassThrough(add_with_div)))(x)
301306
end
302307

303308
function fraction_iszero(x)
304309
!istree(x) && return _iszero(x)
310+
ff = flatten_fractions(x)
305311
# fast path and then slow path
306-
any(_iszero, numerators(flatten_fractions(x))) ||
307-
any(_iszeroexpand, numerators(flatten_fractions(x)))
312+
any(_iszero, numerators(ff)) ||
313+
any(_iszeroexpand, numerators(ff))
308314
end
309315

310316
function fraction_isone(x)

src/rewriters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ rewriters.
2929
3030
"""
3131
module Rewriters
32-
using SymbolicUtils: @timer
32+
using SymbolicUtils: @timer, unsorted_arguments
3333
using TermInterface: is_operation, istree, operation, similarterm, arguments, node_count
3434

3535
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
@@ -159,7 +159,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
159159
x = p.rw(x)
160160
end
161161
if istree(x)
162-
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
162+
x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)))
163163
end
164164
return ord === :post ? p.rw(x) : x
165165
else

src/rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ function (acr::ACRule)(term)
356356
end
357357

358358
T = symtype(term)
359-
args = arguments(term)
359+
args = unsorted_arguments(term)
360360

361361
itr = acr.sets(eachindex(args), acr.arity)
362362

0 commit comments

Comments
 (0)