Skip to content

Commit c93abab

Browse files
committed
Make add_div more general, add flatten_fraction function.
1 parent 93bb2ea commit c93abab

File tree

2 files changed

+40
-29
lines changed

2 files changed

+40
-29
lines changed

src/polyform.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export PolyForm, simplify_fractions, quick_cancel
1+
export PolyForm, simplify_fractions, quick_cancel, flatten_fraction
22
using Bijections
33
using DynamicPolynomials: PolyVar
44

@@ -7,7 +7,8 @@ using DynamicPolynomials: PolyVar
77
88
Abstracts a [MultivariatePolynomials.jl](https://juliaalgebra.github.io/MultivariatePolynomials.jl/stable/) as a SymbolicUtils expression and vice-versa.
99
10-
The SymbolicUtils term interface (`istree`, `operation, and `arguments`) works on PolyForm lazily: the `operation` and `arguments` are created by converting one level of arguments into SymbolicUtils expressions. They may further contain PolyForm within them.
10+
The SymbolicUtils term interface (`istree`, `operation, and `arguments`) works on PolyForm lazily:
11+
the `operation` and `arguments` are created by converting one level of arguments into SymbolicUtils expressions. They may further contain PolyForm within them.
1112
We use this to hold polynomials in memory while doing `simplify_fractions`.
1213
1314
PolyForm{T}(x; Fs=Union{typeof(*),typeof(+),typeof(^)}, recurse=false)
@@ -190,15 +191,16 @@ function TermInterface.arguments(x::PolyForm{T}) where {T}
190191

191192
if MP.nterms(x.p) == 1
192193
MP.isconstant(x.p) && return [convert(Number, x.p)]
193-
c = MP.coefficient(x.p)
194-
t = MP.monomial(x.p)
194+
t = MP.term(x.p)
195+
c = MP.coefficient(t)
196+
m = MP.monomial(t)
195197

196198
if !isone(c)
197199
[c, (unstable_pow(resolve(v), pow)
198-
for (v, pow) in MP.powers(t) if !iszero(pow))...]
200+
for (v, pow) in MP.powers(m) if !iszero(pow))...]
199201
else
200202
[unstable_pow(resolve(v), pow)
201-
for (v, pow) in MP.powers(t) if !iszero(pow)]
203+
for (v, pow) in MP.powers(m) if !iszero(pow)]
202204
end
203205
else
204206
ts = MP.terms(x.p)
@@ -228,7 +230,7 @@ expand(expr) = PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse
228230

229231
## Rational Polynomial form with Div
230232

231-
function polyform_factors(d::Div, pvar2sym, sym2term)
233+
function polyform_factors(d, pvar2sym, sym2term)
232234
make(xs) = map(xs) do x
233235
if x isa Pow && x.base isa Integer && x.exp > 0
234236
# here we do want to recurse one level, that's why it's wrong to just
@@ -255,11 +257,11 @@ function simplify_div(d::Div)
255257
end
256258
end
257259

258-
function add_divs(x::Div, y::Div)
260+
function add_divs(x, y)
259261
x_num, x_den = polyform_factors(x, get_pvar2sym(), get_sym2term())
260262
y_num, y_den = polyform_factors(y, get_pvar2sym(), get_sym2term())
261263

262-
Div(_mul(x_num, y_den) + _mul(x_den, y_num), _mul(x_den, y_den))
264+
(_mul(x_num, y_den) + _mul(x_den, y_num)) / (_mul(x_den, y_den))
263265
end
264266

265267
"""
@@ -278,7 +280,21 @@ function simplify_fractions(x)
278280
rules = [@rule ~x::isdiv => simplify_div(~x)
279281
@acrule ~a::isdiv + ~b::isdiv => add_divs(~a,~b)]
280282

281-
Fixpoint(Postwalk(Chain(rules)))(x)
283+
Fixpoint(Postwalk(RestartedChain(rules)))(x)
284+
end
285+
286+
"""
287+
flatten_fraction(x)
288+
289+
Flatten nested fractions that are added together.
290+
291+
```julia
292+
julia> flatten_fraction((1+(1+1/a)/a)/a)
293+
(1 + a + a^2) / (a^3)
294+
```
295+
"""
296+
function flatten_fraction(x)
297+
Fixpoint(Postwalk(PassThrough(@acrule ~a::(x->x isa Div) + ~b => add_divs(~a,~b))))(x)
282298
end
283299

284300
function needs_div_rules(x)

src/types.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,15 @@ maybe_intcoeff(x) = x
857857
function (::Type{Div{T}})(n, d, simplified=false; metadata=nothing) where {T}
858858
_iszero(n) && return zero(typeof(n))
859859
_isone(d) && return n
860+
861+
if n isa Div && d isa Div
862+
return Div{T}(n.num * d.den, n.den * d.num)
863+
elseif n isa Div
864+
return Div{T}(n.num, n.den * d)
865+
elseif d isa Div
866+
return Div{T}(n * d.den, d.num)
867+
end
868+
860869
d isa Number && _isone(-d) && return -1 * n
861870
n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify
862871

@@ -879,15 +888,11 @@ function Div(n,d, simplified=false; kw...)
879888
Div{promote_symtype((/), symtype(n), symtype(d))}(n,d, simplified; kw...)
880889
end
881890

882-
function numerators(d::Div)
883-
x = d.num
884-
istree(x) && operation(x) == (*) ? arguments(x) : [x]
885-
end
891+
numerators(x) = istree(x) && operation(x) == (*) ? arguments(x) : [x]
892+
numerators(d::Div) = numerators(d.num)
886893

887-
function denominators(d::Div)
888-
x = d.den
889-
istree(x) && operation(x) == (*) ? arguments(x) : [x]
890-
end
894+
denominators(x) = [1]
895+
denominators(d::Div) = numerators(d.den)
891896

892897
TermInterface.istree(d::Type{Div}) = true
893898

@@ -899,17 +904,7 @@ end
899904

900905
Base.show(io::IO, d::Div) = show_term(io, d)
901906

902-
function /(a::Union{SN,Number}, b::SN)
903-
if a isa Div && b isa Div
904-
Div(a.num * b.den, a.den * b.num)
905-
elseif a isa Div
906-
Div(a.num, a.den * b)
907-
elseif b isa Div
908-
Div(a * b.den, b.num)
909-
else
910-
Div(a,b)
911-
end
912-
end
907+
/(a::Union{SN,Number}, b::SN) = Div(a,b)
913908

914909
"""
915910
Pow(base, exp)

0 commit comments

Comments
 (0)