Skip to content

Commit 0112c2c

Browse files
authored
Merge pull request #271 from blegat/multivariate_polynomials
Allow any implementation of the MultivariatePolynomials API
2 parents 1400c0f + 9310bef commit 0112c2c

File tree

5 files changed

+91
-51
lines changed

5 files changed

+91
-51
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "0.12.1"
4+
version = "0.13.0"
55

66
[deps]
7-
AbstractAlgebra = "c3fe647b-3220-5bb0-a1ea-a7954cac585d"
87
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
98
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
109
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1110
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1211
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
12+
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
1313
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1414
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16+
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
1617
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1718
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1819
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -21,14 +22,15 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2122
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2223

2324
[compat]
24-
AbstractAlgebra = "0.9, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18"
2525
AbstractTrees = "0.3"
2626
ChainRulesCore = "0.9, 0.10"
2727
Combinatorics = "1.0"
2828
ConstructionBase = "1.1"
2929
DataStructures = "0.18"
30+
DynamicPolynomials = "0.3"
3031
IfElse = "0.1"
3132
LabelledArrays = "1.5"
33+
MultivariatePolynomials = "0.3"
3234
NaNMath = "0.3"
3335
Setfield = "0.7"
3436
SpecialFunctions = "0.10, 1.0"

benchmark/goldstein_price.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using SymbolicUtils, Test
2+
using SymbolicUtils: Term, symtype
3+
using BenchmarkTools
4+
5+
function goldstein_price()
6+
@syms x1 x2
7+
f1 = x1+x2+1
8+
f2 = 19-14*x1+3*x1^2-14*x2+6*x1*x2+3*x2^2
9+
f3 = 2*x1-3*x2
10+
f4 = 18-32*x1+12*x1^2+48*x2-36*x1*x2+27*x2^2
11+
12+
# f(x) is the Goldstein-Price function
13+
f = (1+f1^2*f2)*(30+f3^2*f4)
14+
@btime expand($f)
15+
f2 = f^2
16+
@btime expand($f2)
17+
f3 = f^3
18+
@btime expand($f3)
19+
end

src/SymbolicUtils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ include("rule.jl")
3636
include("matchers.jl")
3737

3838
# Convert to an efficient multi-variate polynomial representation
39-
import AbstractAlgebra.Generic: MPoly, PolynomialRing, ZZ, exponent_vector
40-
using AbstractAlgebra: ismonomial, symbols
39+
import MultivariatePolynomials
40+
const MP = MultivariatePolynomials
41+
import DynamicPolynomials
4142
export expand
4243
include("abstractalgebra.jl")
4344

src/abstractalgebra.jl

Lines changed: 59 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ with a symbol, store the symbol => term mapping in `dict`.
99
function labels! end
1010

1111
# Turn a Term into a multivariate polynomial
12-
function labels!(dicts, t::Sym)
12+
function labels!(dicts, t::Sym, variable_type::Type)
1313
sym2term, term2sym = dicts
1414
if !haskey(term2sym, t)
1515
sym2term[t] = t
@@ -18,14 +18,14 @@ function labels!(dicts, t::Sym)
1818
return t
1919
end
2020

21-
function labels!(dicts, t)
22-
if t isa Integer
21+
function labels!(dicts, t, variable_type::Type)
22+
if t isa Number
2323
return t
2424
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
2525
tt = arguments(t)
26-
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt), symtype(t))
26+
return similarterm(t, operation(t), map(x->labels!(dicts, x, variable_type), tt), symtype(t))
2727
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
28-
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)), symtype(t))
28+
return similarterm(t, operation(t), map(x->labels!(dicts, x, variable_type), arguments(t)), symtype(t))
2929
else
3030
sym2term, term2sym = dicts
3131
if haskey(term2sym, t)
@@ -36,7 +36,7 @@ function labels!(dicts, t)
3636
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
3737
dicts2 = _dicts(dicts[2])
3838
sym2term[sym] = similarterm(t, operation(t),
39-
map(x->to_mpoly(x, dicts)[1], arguments(t)),
39+
map(x->to_mpoly(x, variable_type, dicts)[1], arguments(t)),
4040
symtype(t))
4141
else
4242
sym = Sym{symtype(t)}(gensym("literal"))
@@ -49,7 +49,7 @@ function labels!(dicts, t)
4949
end
5050
end
5151

52-
ismpoly(x) = x isa MPoly || x isa Integer
52+
ismpoly(x) = x isa MP.AbstractPolynomialLike || x isa Number
5353
isnonnegint(x) = x isa Integer && x >= 0
5454

5555
_dicts(t2s=OrderedDict{Any, Sym}()) = (OrderedDict{Sym, Any}(), t2s)
@@ -71,70 +71,82 @@ let
7171
MPOLY_MAKER = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_rules)), similarterm=simterm))
7272

7373
global to_mpoly
74-
function to_mpoly(t, dicts=_dicts())
74+
function to_mpoly(t, variable_type::Type=DynamicPolynomials.PolyVar{true}, dicts=_dicts())
7575
# term2sym is only used to assign the same
7676
# symbol for the same term -- in other words,
7777
# it does common subexpression elimination
7878
t = MPOLY_CLEANUP(t)
7979
sym2term, term2sym = dicts
80-
labeled = labels!((sym2term, term2sym), t)
80+
labeled = labels!((sym2term, term2sym), t, variable_type)
8181

8282
if isempty(sym2term)
8383
return MPOLY_MAKER(labeled), Dict{Sym,Any}()
8484
end
8585

8686
ks = sort(collect(keys(sym2term)), lt=<ₑ)
87-
R, vars = PolynomialRing(ZZ, String.(nameof.(ks)))
87+
vars = MP.similarvariable.(variable_type, nameof.(ks))
8888

89-
replace_with_poly = Dict{Sym,MPoly}(zip(ks, vars))
89+
replace_with_poly = Dict{Sym,eltype(vars)}(zip(ks, vars))
9090
t_poly = substitute(labeled, replace_with_poly, fold=false)
9191
MPOLY_MAKER(t_poly), sym2term
9292
end
9393
end
9494

9595
function to_term(reference, x, dict)
96-
syms = Dict(zip(nameof.(keys(dict)), keys(dict)))
96+
syms = Dict(zip(string.(nameof.(keys(dict))), keys(dict)))
9797
dict = copy(dict)
9898
for (k, v) in dict
9999
dict[k] = _to_term(reference, v, dict, syms)
100100
end
101-
_to_term(reference, x, dict, syms)
101+
return _to_term(reference, x, dict, syms)
102+
#return substitute(t, dict, fold=false)
102103
end
103104

104-
function _to_term(reference, x::MPoly, dict, syms)
105-
106-
function mul_coeffs(exps, ring)
107-
l = length(syms)
108-
ss = symbols(ring)
109-
monics = [e == 1 ? syms[ss[i]] : syms[ss[i]]^e for (i, e) in enumerate(exps) if !iszero(e)]
110-
if length(monics) == 1
111-
return monics[1]
112-
elseif length(monics) == 0
113-
return 1
114-
else
115-
return similarterm(reference, *, monics, symtype(reference))
105+
_to_term(reference, x::Number, dict, syms) = x
106+
_to_term(reference, var::MP.AbstractVariable, dict, syms) = substitute(syms[MP.name(var)], dict, fold=false)
107+
function _to_term(reference, mono::MP.AbstractMonomialLike, dict, syms)
108+
monics = [
109+
begin
110+
t = _to_term(reference, var, dict, syms)
111+
exp == 1 ? t : t^exp
116112
end
113+
for (var, exp) in MP.powers(mono) if !iszero(exp)
114+
]
115+
if length(monics) == 1
116+
return monics[1]
117+
elseif isempty(monics)
118+
return 1
119+
else
120+
return similarterm(reference, *, monics, symtype(reference))
117121
end
122+
end
118123

119-
monoms = [mul_coeffs(exponent_vector(x, i), x.parent) for i in 1:x.length]
120-
if length(monoms) == 0
121-
return 0
122-
elseif length(monoms) == 1
123-
t = !isone(x.coeffs[1]) ? monoms[1] * Int(x.coeffs[1]) : monoms[1]
124+
function _to_term(reference, term::MP.AbstractTermLike, dict, syms)
125+
coef = MP.coefficient(term)
126+
mono = _to_term(reference, MP.monomial(term), dict, syms)
127+
if isone(coef)
128+
return mono
124129
else
125-
t = similarterm(reference,
126-
+,
127-
map((x,y)->isone(y) ? x : Int(y)*x,
128-
monoms, x.coeffs[1:length(monoms)]),
129-
symtype(reference))
130+
return MP.coefficient(term) * mono
130131
end
132+
end
131133

132-
substitute(t, dict, fold=false)
134+
function _to_term(reference, x::MP.AbstractPolynomialLike, dict, syms)
135+
if MP.nterms(x) == 0
136+
return 0
137+
elseif MP.nterms(x) == 1
138+
return _to_term(reference, first(MP.terms(x)), dict, syms)
139+
else
140+
terms = map(MP.terms(x)) do term
141+
_to_term(reference, term, dict, syms)
142+
end
143+
return similarterm(reference, +, terms, symtype(reference))
144+
end
133145
end
134146

135147
function _to_term(reference, x, dict, vars)
136148
if istree(x)
137-
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)), symtype(x))
149+
t = similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)), symtype(x))
138150
else
139151
if haskey(dict, x)
140152
return dict[x]
@@ -144,18 +156,21 @@ function _to_term(reference, x, dict, vars)
144156
end
145157
end
146158

147-
<(a::MPoly, b::MPoly) = false
159+
<(a::MP.AbstractPolynomialLike, b::MP.AbstractPolynomialLike) = false
148160

149161
"""
150-
expand(expr)
162+
expand(expr, variable_type::Type=DynamicPolynomials.PolyVar{true})
151163
152-
Expand expressions by distributing multiplication over addition.
164+
Expand expressions by distributing multiplication over addition, e.g.,
165+
`a*(b+c)` becomes `ab+ac`.
153166
154-
`a*(b+c)` becomes `ab+ac`. `expand` uses [AbstractAlgebra.jl](https://nemocas.github.io/AbstractAlgebra.jl/latest/) to construct
155-
dense Multi-variate polynomial to do this very fast.
167+
`expand` uses replace symbols and non-algebraic expressions by variables of type
168+
`variable_type` to compute the distribution using a specialized sparse
169+
multivariate polynomials implementation.
170+
`variable_type` can be any subtype of `MultivariatePolynomials.AbstractVariable`.
156171
"""
157-
function expand(x)
158-
to_term(x, to_mpoly(x)...)
172+
function expand(expr, variable_type::Type=DynamicPolynomials.PolyVar{true})
173+
to_term(expr, to_mpoly(expr, variable_type)...)
159174
end
160175

161176
Base.@deprecate polynormalize(x) expand(x)

src/types.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ end
494494
function show_mul(io, args)
495495
length(args) == 1 && return print_arg(io, *, args[1])
496496

497-
paren_scalar = args[1] isa Complex || args[1] isa Rational
497+
paren_scalar = args[1] isa Complex || args[1] isa Rational || (args[1] isa Number && !isfinite(args[1]))
498498
minus = args[1] isa Number && args[1] == -1
499499
unit = args[1] isa Number && args[1] == 1
500500
nostar = !paren_scalar && args[1] isa Number && !(args[2] isa Number)
@@ -867,7 +867,10 @@ istree(a::Pow) = true
867867

868868
operation(a::Pow) = ^
869869

870-
arguments(a::Pow) = [a.base, a.exp]
870+
# Use `Union` to avoid promoting the base and exponent to the same type.
871+
# For instance, if `a.base` is a multivariate polynomial and `a.exp` is a number,
872+
# we don't want to promote `a.exp` to a multivariate polynomial.
873+
arguments(a::Pow) = Union{typeof(a.base), typeof(a.exp)}[a.base, a.exp]
871874

872875
Base.hash(p::Pow, u::UInt) = hash(p.exp, hash(p.base, u))
873876

0 commit comments

Comments
 (0)