Skip to content

Commit 93bb2ea

Browse files
authored
Merge pull request #348 from JuliaSymbolics/s/div
Some more canonicalization in Div
2 parents 2c80498 + 1811082 commit 93bb2ea

File tree

7 files changed

+205
-23
lines changed

7 files changed

+205
-23
lines changed

src/polyform.jl

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export PolyForm, simplify_fractions
1+
export PolyForm, simplify_fractions, quick_cancel
22
using Bijections
33
using DynamicPolynomials: PolyVar
44

@@ -230,10 +230,10 @@ expand(expr) = PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse
230230

231231
function polyform_factors(d::Div, pvar2sym, sym2term)
232232
make(xs) = map(xs) do x
233-
if x isa Pow && arguments(x)[2] isa Integer && arguments(x)[2] > 0
233+
if x isa Pow && x.base isa Integer && x.exp > 0
234234
# here we do want to recurse one level, that's why it's wrong to just
235235
# use Fs = Union{typeof(+), typeof(*)} here.
236-
Pow(PolyForm(arguments(x)[1], pvar2sym, sym2term), arguments(x)[2])
236+
Pow(PolyForm(x.base, pvar2sym, sym2term), x.exp)
237237
else
238238
PolyForm(x, pvar2sym, sym2term)
239239
end
@@ -244,14 +244,14 @@ end
244244

245245
_mul(xs...) = all(isempty, xs) ? 1 : *(Iterators.flatten(xs)...)
246246

247-
function simplify_fractions(d::Div)
247+
function simplify_div(d::Div)
248248
d.simplified && return d
249249
ns, ds = polyform_factors(d, get_pvar2sym(), get_sym2term())
250250
ns, ds = rm_gcds(ns, ds)
251251
if all(_isone, ds)
252252
return isempty(ns) ? 1 : simplify_fractions(_mul(ns))
253253
else
254-
return Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds)), true)
254+
Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds)), true)
255255
end
256256
end
257257

@@ -269,12 +269,26 @@ Find `Div` nodes and simplify them by cancelling a set of factors of numerators
269269
and denominators. It may leave some expressions in `PolyForm` format.
270270
"""
271271
function simplify_fractions(x)
272+
x = Postwalk(quick_cancel)(x)
273+
274+
!needs_div_rules(x) && return x
275+
272276
isdiv(x) = x isa Div
273277

274-
rules = [@acrule ~a::isdiv + ~b::isdiv => add_divs(~a,~b)
275-
@rule ~x::isdiv => simplify_fractions(~x)]
278+
rules = [@rule ~x::isdiv => simplify_div(~x)
279+
@acrule ~a::isdiv + ~b::isdiv => add_divs(~a,~b)]
280+
281+
Fixpoint(Postwalk(Chain(rules)))(x)
282+
end
283+
284+
function needs_div_rules(x)
285+
(x isa Div && !(x.num isa Number) && !(x.den isa Number)) ||
286+
(istree(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||
287+
(istree(x) && any(needs_div_rules, unsorted_arguments(x)))
288+
end
276289

277-
Prewalk(RestartedChain(rules))(x)
290+
function has_div(x)
291+
return x isa Div || (istree(x) && any(has_div, unsorted_arguments(x)))
278292
end
279293

280294
flatten_pows(xs) = map(xs) do x
@@ -289,6 +303,110 @@ const MaybeGcd = Union{PolyForm, MP.AbstractPolynomialLike, Integer}
289303
_gcd(x::MaybeGcd, y::MaybeGcd) = (coefftype(x) <: Complex || coefftype(y) <: Complex) ? 1 : gcd(x, y)
290304
_gcd(x, y) = 1
291305

306+
307+
"""
308+
quick_cancel(d::Div)
309+
310+
Cancel out matching factors from numerator and denominator.
311+
This is not as effective as `simplify_fractions`, for example,
312+
it wouldn't simplify `(x^2 + 15 - 8x) / (x - 5)` to `(x - 3)`.
313+
But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`.
314+
Has optimized processes for `Mul` and `Pow` terms.
315+
"""
316+
quick_cancel(d::Div) = Div{symtype(d)}(quick_cancel(d.num, d.den)...)
317+
318+
quick_cancel(x) = x
319+
320+
quick_cancel(x, y) = isequal(x, y) ? (1,1) : (x, y)
321+
322+
function quick_cancel(x::Pow, y)
323+
x.exp isa Number || return (x, y)
324+
isequal(x.base, y) && x.exp >= 1 ? (Pow{symtype(x)}(x.base, x.exp - 1),1) : (x, y)
325+
end
326+
327+
quick_cancel(y, x::Pow) = reverse(quick_cancel(x,y))
328+
329+
function quick_cancel(x::Pow, y::Pow)
330+
if isequal(x.base, y.base)
331+
!(x.exp isa Number && y.exp isa Number) && return (x, y)
332+
if x.exp > y.exp
333+
return Pow{symtype(x)}(x.base, x.exp-y.exp), 1
334+
elseif x.exp == y.exp
335+
return 1, 1
336+
else # x.exp < y.exp
337+
return 1, Pow{symtype(y)}(y.base, y.exp-x.exp)
338+
end
339+
end
340+
return x, y
341+
end
342+
343+
function quick_cancel(x::Mul, y)
344+
if haskey(x.dict, y) && x.dict[y] >= 1
345+
d = copy(x.dict)
346+
if d[y] > 1
347+
d[y] -= 1
348+
elseif d[y] == 1
349+
delete!(d, y)
350+
else
351+
error("Can't reach")
352+
end
353+
354+
return Mul(symtype(x), x.coeff, d), 1
355+
else
356+
return x, y
357+
end
358+
end
359+
360+
function quick_cancel(x::Mul, y::Pow)
361+
y.exp isa Number || return (x, y)
362+
if haskey(x.dict, y.base)
363+
d = copy(x.dict)
364+
if x.dict[y.base] > y.exp
365+
d[y.base] -= y.exp
366+
den = 1
367+
elseif x.dict[y.base] == y.exp
368+
delete!(d, y.base)
369+
den = 1
370+
else
371+
den = Pow{symtype(y)}(y.base, y.exp-d[y.base])
372+
delete!(d, y.base)
373+
end
374+
return Mul(symtype(x), x.coeff, d), den
375+
else
376+
return x, y
377+
end
378+
end
379+
380+
quick_cancel(x::Pow, y::Mul) = reverse(quick_cancel(y,x))
381+
382+
quick_cancel(y, x::Mul) = reverse(quick_cancel(x,y))
383+
384+
function quick_cancel(x::Mul, y::Mul)
385+
num_dict, den_dict = _merge_div(x.dict, y.dict)
386+
Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict)
387+
end
388+
389+
function _merge_div(ndict, ddict)
390+
num = copy(ndict)
391+
den = copy(ddict)
392+
for (k, v) in den
393+
if haskey(num, k)
394+
nk = num[k]
395+
if nk > v
396+
num[k] -= v
397+
delete!(den, k)
398+
elseif nk == v
399+
delete!(num, k)
400+
delete!(den, k)
401+
else
402+
den[k] -= nk
403+
delete!(num, k)
404+
end
405+
end
406+
end
407+
num, den
408+
end
409+
292410
function rm_gcds(ns, ds)
293411
ns = flatten_pows(ns)
294412
ds = flatten_pows(ds)

src/types.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,37 @@ end
841841
Base.hash(x::Div, u::UInt64) = hash(x.num, hash(x.den, u))
842842
Base.isequal(x::Div, y::Div) = isequal(x.num, y.num) && isequal(x.den, y.den)
843843

844+
const Rat = Union{Rational, Integer}
845+
846+
ratcoeff(x) = false, NaN
847+
ratcoeff(x::Rat) = true, x
848+
ratcoeff(x::Mul) = ratcoeff(x.coeff)
849+
ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y
850+
ratio(x::Rat,y::Rat) = x//y
851+
function maybe_intcoeff(x::Mul)
852+
x.coeff isa Rational && isone(x.coeff.den) ? Setfield.@set!(x.coeff = x.coeff.num) : x
853+
end
854+
maybe_intcoeff(x::Rational) = isone(x.den) ? x.num : x
855+
maybe_intcoeff(x) = x
856+
844857
function (::Type{Div{T}})(n, d, simplified=false; metadata=nothing) where {T}
845-
@assert !(n isa AbstractArray)
846-
@assert !(d isa AbstractArray)
858+
_iszero(n) && return zero(typeof(n))
859+
_isone(d) && return n
860+
d isa Number && _isone(-d) && return -1 * n
861+
n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify
862+
863+
# GCD coefficient upon construction
864+
rat, nc = ratcoeff(n)
865+
if rat
866+
rat, dc = ratcoeff(d)
867+
if rat
868+
g = gcd(nc, dc) * sign(dc) # make denominator positive
869+
invdc = ratio(1, g)
870+
n = maybe_intcoeff(invdc * n)
871+
d = maybe_intcoeff(invdc * d)
872+
end
873+
end
874+
847875
Div{T, typeof(n), typeof(d), typeof(metadata)}(n, d, simplified, metadata)
848876
end
849877

test/basics.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end
162162
@test repr(-(a + b)) == "-a - b"
163163
@test repr((2a)^(-2a)) == "(2a)^(-2a)"
164164
@test repr(1/2a) == "1 / (2a)"
165-
@test repr(2/(2*a)) == "2 / (2a)"
165+
@test repr(2/(2*a)) == "1 / a"
166166
@test repr(Term(*, [1, 1])) == "1"
167167
@test repr(Term(*, [2, 1])) == "2*1"
168168
@test repr((a + b) - (b + c)) == "a - c"
@@ -226,3 +226,15 @@ end
226226
@syms a b c::Int
227227
@test isequal(arguments(s(a, b, c)), [a, b, c])
228228
end
229+
230+
@testset "div" begin
231+
@syms x y
232+
@test (2x/2y).num isa Sym
233+
@test (2x/3y).num.coeff == 2
234+
@test (2x/3y).den.coeff == 3
235+
@test (2x/-3x).num.coeff == -2
236+
@test (2x/-3x).den.coeff == 3
237+
@test (2.5x/3x).num.coeff == 2.5
238+
@test (2.5x/3x).den.coeff == 3
239+
@test (x/3x).den.coeff == 3
240+
end

test/polyform.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
using SymbolicUtils: PolyForm, Term, symtype
22
using Test, SymbolicUtils
33

4+
include("utils.jl")
5+
46
@testset "div and polyform" begin
57
@syms x y z
68
@test repr(PolyForm(x-y)) == "x - y"
79
@test repr(x/y*x/z) == "(x^2) / (y*z)"
810
@test repr(simplify_fractions(((x-y+z)*(x+4z+1)) /
911
(y*(2x - 3y + 3z) +
10-
x*(x + z)))) == "(1 + x + 4z) / (x + 3.0y)"
12+
x*(x + z)))) == repr(simplify_fractions((1 + x + 4z) / (x + 3.0y)))
1113
@test simplify_fractions(x/(x+3) + 3/(x+3)) == 1
1214
@test repr(simplify(simplify_fractions(cos(x)/sin(x) + sin(x)/cos(x)))) == "1 / (cos(x)*sin(x))"
1315
end
@@ -31,3 +33,23 @@ end
3133
#@test expand(identity(a * b) - b * a) == 0
3234
@test expand(a * b - b * a) == 0
3335
end
36+
37+
@testset "simplify_fractions with quick-cancel" begin
38+
@syms x y
39+
@test simplify_fractions(x/2x) == 1//2
40+
@test simplify_fractions(2x//x) == 2
41+
@eqtest simplify_fractions((x+y) * (x-y) / (x+y)) == (x-y)
42+
@eqtest simplify_fractions(x^3 * y / x) == y*x^2
43+
@eqtest simplify_fractions(2x^3 * y / x) == 2y*x^2
44+
@eqtest simplify_fractions(x / (3(x^3)*y)) == simplify_fractions(1/(3*(y*x^2)))
45+
@eqtest simplify_fractions(2x / (3(x^3)*y)) == simplify_fractions(2/(3*(y*x^2)))
46+
@eqtest simplify_fractions(x^2 / (3(x^3)*y)) == simplify_fractions(1/(3*(y*x)))
47+
@eqtest simplify_fractions((3(x^3)*y) / x^2) == simplify_fractions(3*(y*x))
48+
@eqtest simplify_fractions(x^2/x^4) == (1/x^2)
49+
@eqtest simplify_fractions(x^2/x^3) == 1/x
50+
@eqtest simplify_fractions(x^3/x^2) == x
51+
@eqtest simplify_fractions(x^2/x^2) == 1
52+
@eqtest simplify_fractions(3x^2/x^3) == 3/x
53+
@eqtest simplify_fractions(3*(x^2)*(y^3)/(3*(x^3)*(y^2))) == y/x
54+
@eqtest simplify_fractions(3*(x^x)/x*y) == 3*(x^x)/x*y
55+
end

test/rulesets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
@eqtest simplify(a + b + 0*c + d) == simplify(a + b + d)
4343
@eqtest simplify(a * b * c^0 * d) == simplify(a * b * d)
4444
@eqtest simplify(a * b * 1*c * d) == simplify(a * b * c * d)
45-
@eqtest repr(simplify_fractions(x^2.0/(x*y)^2.0)) == "1 / (y^2.0)"
45+
@eqtest simplify_fractions(x^2.0/(x*y)^2.0) == simplify_fractions(1 / (y^2.0))
4646

4747
@test simplify(Term(one, [a])) == 1
4848
@test simplify(Term(one, [b+1])) == 1

test/runtests.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,10 @@ if v"1.6" ≤ VERSION < v"1.7-beta3.0"
1717
else
1818
@warn "Skipping doctests"
1919
end
20-
21-
# == / != syntax is nice, let's use it in tests
22-
macro eqtest(expr)
23-
@assert expr.head == :call && expr.args[1] in [:(==), :(!=)]
24-
if expr.args[1] == :(==)
25-
:(@test isequal($(expr.args[2]), $(expr.args[3])))
26-
else
27-
:(@test !isequal($(expr.args[2]), $(expr.args[3])))
28-
end |> esc
29-
end
3020
SymbolicUtils.show_simplified[] = false
3121

22+
include("utils.jl")
23+
3224
if haskey(ENV, "SU_BENCHMARK_ONLY")
3325
include("benchmark.jl")
3426
else

test/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
# == / != syntax is nice, let's use it in tests
3+
macro eqtest(expr)
4+
@assert expr.head == :call && expr.args[1] in [:(==), :(!=)]
5+
if expr.args[1] == :(==)
6+
:(@test isequal($(expr.args[2]), $(expr.args[3])))
7+
else
8+
:(@test !isequal($(expr.args[2]), $(expr.args[3])))
9+
end |> esc
10+
end

0 commit comments

Comments
 (0)