Skip to content

Commit 3c4d76b

Browse files
authored
Merge pull request #351 from JuliaSymbolics/s/flatten_frac
`flatten_fraction`
2 parents 93bb2ea + c572b74 commit 3c4d76b

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

src/polyform.jl

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export PolyForm, simplify_fractions, quick_cancel
1+
export PolyForm, simplify_fractions, quick_cancel, flatten_fractions
22
using Bijections
33
using DynamicPolynomials: PolyVar
44

@@ -7,7 +7,8 @@ using DynamicPolynomials: PolyVar
77
88
Abstracts a [MultivariatePolynomials.jl](https://juliaalgebra.github.io/MultivariatePolynomials.jl/stable/) as a SymbolicUtils expression and vice-versa.
99
10-
The SymbolicUtils term interface (`istree`, `operation, and `arguments`) works on PolyForm lazily: the `operation` and `arguments` are created by converting one level of arguments into SymbolicUtils expressions. They may further contain PolyForm within them.
10+
The SymbolicUtils term interface (`istree`, `operation, and `arguments`) works on PolyForm lazily:
11+
the `operation` and `arguments` are created by converting one level of arguments into SymbolicUtils expressions. They may further contain PolyForm within them.
1112
We use this to hold polynomials in memory while doing `simplify_fractions`.
1213
1314
PolyForm{T}(x; Fs=Union{typeof(*),typeof(+),typeof(^)}, recurse=false)
@@ -190,15 +191,16 @@ function TermInterface.arguments(x::PolyForm{T}) where {T}
190191

191192
if MP.nterms(x.p) == 1
192193
MP.isconstant(x.p) && return [convert(Number, x.p)]
193-
c = MP.coefficient(x.p)
194-
t = MP.monomial(x.p)
194+
t = MP.term(x.p)
195+
c = MP.coefficient(t)
196+
m = MP.monomial(t)
195197

196198
if !isone(c)
197199
[c, (unstable_pow(resolve(v), pow)
198-
for (v, pow) in MP.powers(t) if !iszero(pow))...]
200+
for (v, pow) in MP.powers(m) if !iszero(pow))...]
199201
else
200202
[unstable_pow(resolve(v), pow)
201-
for (v, pow) in MP.powers(t) if !iszero(pow)]
203+
for (v, pow) in MP.powers(m) if !iszero(pow)]
202204
end
203205
else
204206
ts = MP.terms(x.p)
@@ -223,12 +225,12 @@ Expand expressions by distributing multiplication over addition, e.g.,
223225
multivariate polynomials implementation.
224226
`variable_type` can be any subtype of `MultivariatePolynomials.AbstractVariable`.
225227
"""
226-
expand(expr) = PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true)
228+
expand(expr) = Postwalk(identity)(PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true))
227229

228230

229231
## Rational Polynomial form with Div
230232

231-
function polyform_factors(d::Div, pvar2sym, sym2term)
233+
function polyform_factors(d, pvar2sym, sym2term)
232234
make(xs) = map(xs) do x
233235
if x isa Pow && x.base isa Integer && x.exp > 0
234236
# here we do want to recurse one level, that's why it's wrong to just
@@ -255,11 +257,11 @@ function simplify_div(d::Div)
255257
end
256258
end
257259

258-
function add_divs(x::Div, y::Div)
260+
function add_divs(x, y)
259261
x_num, x_den = polyform_factors(x, get_pvar2sym(), get_sym2term())
260262
y_num, y_den = polyform_factors(y, get_pvar2sym(), get_sym2term())
261263

262-
Div(_mul(x_num, y_den) + _mul(x_den, y_num), _mul(x_den, y_den))
264+
(_mul(x_num, y_den) + _mul(x_den, y_num)) / (_mul(x_den, y_den))
263265
end
264266

265267
"""
@@ -278,9 +280,38 @@ function simplify_fractions(x)
278280
rules = [@rule ~x::isdiv => simplify_div(~x)
279281
@acrule ~a::isdiv + ~b::isdiv => add_divs(~a,~b)]
280282

283+
Fixpoint(Postwalk(RestartedChain(rules)))(x)
284+
end
285+
286+
"""
287+
flatten_fractions(x)
288+
289+
Flatten nested fractions that are added together.
290+
291+
```julia
292+
julia> flatten_fractions((1+(1+1/a)/a)/a)
293+
(1 + a + a^2) / (a^3)
294+
```
295+
"""
296+
function flatten_fractions(x)
297+
rules = [@acrule ~a::(x->x isa Div) + ~b => add_divs(~a,~b)
298+
@rule *(~~x, ~a / ~b, ~~y) / ~c => *((~~x)..., ~a, (~~y)...) / (~b * ~c)
299+
@rule ~c / *(~~x, ~a / ~b, ~~y) => (~b * ~c) / *((~~x)..., ~a, (~~y)...)]
281300
Fixpoint(Postwalk(Chain(rules)))(x)
282301
end
283302

303+
function fraction_iszero(x)
304+
!istree(x) && return _iszero(x)
305+
# fast path and then slow path
306+
any(_iszero, numerators(flatten_fractions(x))) ||
307+
any(_iszeroexpand, numerators(flatten_fractions(x)))
308+
end
309+
310+
function fraction_isone(x)
311+
!istree(x) && return _isone(x)
312+
_isone(simplify_fractions(flatten_fractions(x)))
313+
end
314+
284315
function needs_div_rules(x)
285316
(x isa Div && !(x.num isa Number) && !(x.den isa Number)) ||
286317
(istree(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||

src/types.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,15 @@ maybe_intcoeff(x) = x
857857
function (::Type{Div{T}})(n, d, simplified=false; metadata=nothing) where {T}
858858
_iszero(n) && return zero(typeof(n))
859859
_isone(d) && return n
860+
861+
if n isa Div && d isa Div
862+
return Div{T}(n.num * d.den, n.den * d.num)
863+
elseif n isa Div
864+
return Div{T}(n.num, n.den * d)
865+
elseif d isa Div
866+
return Div{T}(n * d.den, d.num)
867+
end
868+
860869
d isa Number && _isone(-d) && return -1 * n
861870
n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify
862871

@@ -879,15 +888,11 @@ function Div(n,d, simplified=false; kw...)
879888
Div{promote_symtype((/), symtype(n), symtype(d))}(n,d, simplified; kw...)
880889
end
881890

882-
function numerators(d::Div)
883-
x = d.num
884-
istree(x) && operation(x) == (*) ? arguments(x) : [x]
885-
end
891+
numerators(x) = istree(x) && operation(x) == (*) ? arguments(x) : [x]
892+
numerators(d::Div) = numerators(d.num)
886893

887-
function denominators(d::Div)
888-
x = d.den
889-
istree(x) && operation(x) == (*) ? arguments(x) : [x]
890-
end
894+
denominators(x) = [1]
895+
denominators(d::Div) = numerators(d.den)
891896

892897
TermInterface.istree(d::Type{Div}) = true
893898

@@ -899,17 +904,7 @@ end
899904

900905
Base.show(io::IO, d::Div) = show_term(io, d)
901906

902-
function /(a::Union{SN,Number}, b::SN)
903-
if a isa Div && b isa Div
904-
Div(a.num * b.den, a.den * b.num)
905-
elseif a isa Div
906-
Div(a.num, a.den * b)
907-
elseif b isa Div
908-
Div(a * b.den, b.num)
909-
else
910-
Div(a,b)
911-
end
912-
end
907+
/(a::Union{SN,Number}, b::SN) = Div(a,b)
913908

914909
"""
915910
Pow(base, exp)

test/polyform.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,16 @@ end
5353
@eqtest simplify_fractions(3*(x^2)*(y^3)/(3*(x^3)*(y^2))) == y/x
5454
@eqtest simplify_fractions(3*(x^x)/x*y) == 3*(x^x)/x*y
5555
end
56+
57+
@testset "isone iszero" begin
58+
@syms a b c d e f g h i
59+
x = (f + ((((g*(c^2)*(e^2)) / d - e*h*(c^2)) / b + (-c*e*f*g) / d + c*e*i) /
60+
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d) - c*e) / b +
61+
((g*(f^2)) / d + ((-c*e*f*g) / d + c*f*h) / b - f*i) /
62+
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d)) / d
63+
64+
o = (d + (e*((c*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d))) / b + (-f*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d)) / d
65+
@test SymbolicUtils.fraction_iszero(x)
66+
@test !SymbolicUtils.fraction_isone(x)
67+
@test SymbolicUtils.fraction_isone(o)
68+
end

0 commit comments

Comments
 (0)