Skip to content

Commit b6a7675

Browse files
authored
Merge pull request #103 from JuliaSymbolics/s/mpoly
Convert to and from MPoly
2 parents 2bd03bc + bcf6da0 commit b6a7675

File tree

12 files changed

+186
-22
lines changed

12 files changed

+186
-22
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ authors = ["Shashi Gowda"]
44
version = "0.3.3"
55

66
[deps]
7+
AbstractAlgebra = "c3fe647b-3220-5bb0-a1ea-a7954cac585d"
78
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
9+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
810
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1012
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

benchmark/benchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
5252
ex = random_term(1000, atoms=[a, b, c, d, a^(-1), b^(-1), 1, 2.0], funs=[+, *])
5353

5454
overhead["simplify_no_fixp"]["randterm:serial"] = @benchmarkable simplify($ex, threaded=false, fixpoint=false)
55-
overhead["simplify_no_fixp"]["randterm:thread"] = @benchmarkable simplify($ex, threaded=true, fixpoint=false)
55+
overhead["simplify_no_fixp"]["randterm:thread"] = @benchmarkable simplify($ex, threaded=true, fixpoint=false)
5656
end

src/SymbolicUtils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,12 @@ using Combinatorics: permutations
4444
export @rule, @acrule, RuleSet
4545
include("rule_dsl.jl")
4646

47-
export simplify, substitute
47+
import AbstractAlgebra.Generic: MPoly, PolynomialRing, ZZ, exponent_vector
48+
using AbstractAlgebra: ismonomial
49+
using DataStructures
50+
include("abstractalgebra.jl")
4851

52+
export simplify, substitute
4953
include("simplify.jl")
5054

5155
include("rulesets.jl")

src/abstractalgebra.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Polynomial Normal Form
2+
3+
"""
4+
labels!(dict, t)
5+
6+
Find all terms that are not + and * and replace them
7+
with a symbol, store the symbol => term mapping in `dict`.
8+
"""
9+
function labels! end
10+
11+
# Turn a Term into a multivariate polynomial
12+
function labels!(dicts, t::Sym)
13+
sym2term, term2sym = dicts
14+
if !haskey(term2sym, t)
15+
sym2term[t] = t
16+
term2sym[t] = t
17+
end
18+
return t
19+
end
20+
21+
function labels!(dicts, t)
22+
if t isa Integer
23+
return t
24+
elseif t isa Term && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
25+
tt = arguments(t)
26+
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
27+
elseif t isa Term && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
28+
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
29+
else
30+
sym2term, term2sym = dicts
31+
if haskey(term2sym, t)
32+
return term2sym[t]
33+
end
34+
if t isa Term
35+
tt = arguments(t)
36+
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
37+
sym2term[sym] = Term{symtype(t)}(operation(t),
38+
map(x->to_mpoly(x, dicts)[1], arguments(t)))
39+
else
40+
sym = Sym{symtype(t)}(gensym("literal"))
41+
sym2term[sym] = t
42+
end
43+
44+
term2sym[t] = sym
45+
46+
return sym
47+
end
48+
end
49+
50+
ismpoly(x) = x isa MPoly || x isa Integer
51+
isnonnegint(x) = x isa Integer && x >= 0
52+
53+
const mpoly_rules = RuleSet([@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
54+
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
55+
@rule(+(~x) => ~x)
56+
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
57+
@rule(*(~x) => ~x)
58+
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))])
59+
function to_mpoly(t, dicts=(OrderedDict{Sym, Any}(), OrderedDict{Any, Sym}()))
60+
# term2sym is only used to assign the same
61+
# symbol for the same term -- in other words,
62+
# it does common subexpression elimination
63+
64+
sym2term, term2sym = dicts
65+
labeled = labels!((sym2term, term2sym), t)
66+
67+
if isempty(sym2term)
68+
return labeled, []
69+
end
70+
71+
ks = sort(collect(keys(sym2term)), lt=<ₑ)
72+
R, vars = PolynomialRing(ZZ, String.(nameof.(ks)))
73+
74+
replace_with_poly = Dict{Sym,MPoly}(zip(ks, vars))
75+
t_poly = substitute(labeled, replace_with_poly, fold=false)
76+
simplify(t_poly, EmptyCtx(), rules=mpoly_rules),
77+
sym2term,
78+
reverse(ks)
79+
end
80+
81+
function to_term(x, dict, syms)
82+
dict = copy(dict)
83+
for (k, v) in dict
84+
dict[k] = _to_term(v, dict, syms)
85+
end
86+
_to_term(x, dict, syms)
87+
end
88+
89+
function _to_term(x::MPoly, dict, syms)
90+
91+
function mul_coeffs(exps)
92+
monics = [e == 1 ? syms[i] : syms[i]^e for (i, e) in enumerate(reverse(exps)) if !iszero(e)]
93+
if length(monics) == 1
94+
return monics[1]
95+
elseif length(monics) == 0
96+
return 1
97+
else
98+
return Term(*, monics)
99+
end
100+
end
101+
102+
monoms = [mul_coeffs(exponent_vector(x, i)) for i in 1:x.length]
103+
if length(monoms) == 0
104+
return 0
105+
end
106+
if length(monoms) == 1
107+
t = !isone(x.coeffs[1]) ? monoms[1] * x.coeffs[1] : monoms[1]
108+
else
109+
t = Term(+, map((x,y)->isone(y) ? x : y*x, monoms, x.coeffs[1:length(monoms)]))
110+
end
111+
112+
substitute(t, dict, fold=false)
113+
end
114+
115+
function _to_term(x, dict, vars)
116+
if haskey(dict, x)
117+
return dict[x]
118+
else
119+
return x
120+
end
121+
end
122+
123+
function _to_term(x::Term, dict, vars)
124+
t=Term{symtype(x)}(operation(x), _to_term.(arguments(x), (dict,), (vars,)))
125+
end
126+
127+
<(a::MPoly, b::MPoly) = false

src/rule_dsl.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ getdepth(r::Rule) = r.depth
1616

1717
function rule_depth(rule, d=0, maxdepth=0)
1818
if rule isa Term
19-
maxdepth = maximum(rule_depth(r, d+1, maxdepth) for r in arguments(rule))
19+
maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1)
2020
elseif rule isa Slot || rule isa Segment
2121
maxdepth = max(d, maxdepth)
2222
end
@@ -28,9 +28,10 @@ function Base.show(io::IO, r::Rule)
2828
end
2929

3030
const EMPTY_DICT = ImmutableDict{Symbol, Any}(:____, nothing)
31+
struct DefaultCtx end
3132
struct EmptyCtx end
3233

33-
function (r::Rule)(term, ctx=EmptyCtx())
34+
function (r::Rule)(term, ctx=DefaultCtx())
3435
rhs = r.rhs
3536

3637
r.matcher((term,), EMPTY_DICT, ctx) do bindings, n
@@ -192,7 +193,7 @@ end
192193

193194
Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")")
194195

195-
function (acr::ACRule)(term, ctx=EmptyCtx())
196+
function (acr::ACRule)(term, ctx=DefaultCtx())
196197
r = Rule(acr)
197198
if !(term isa Term)
198199
r(term)
@@ -220,7 +221,7 @@ end
220221
#### Rulesets
221222

222223
"""
223-
RuleSet(rules::Vector{AbstractRules}, context=EmptyCtx())(expr; depth=typemax(Int), applyall=false, recurse=true)
224+
RuleSet(rules::Vector{AbstractRules}, context=DefaultCtx())(expr; depth=typemax(Int), applyall=false, recurse=true)
224225
225226
`RuleSet` is an `AbstractRule` which applies the given `rules` throughout an `expr` with the
226227
context `context`.
@@ -267,7 +268,7 @@ end
267268

268269
const rule_repr = IdDict()
269270

270-
function (r::RuleSet)(term, context=EmptyCtx();
271+
function (r::RuleSet)(term, context=DefaultCtx();
271272
depth=typemax(Int),
272273
applyall::Bool=false,
273274
recurse::Bool=true,

src/rulesets.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ const PLUS_RULES = RuleSet([
2323
@acrule(~a::isnumber + ~b::isnumber => ~a + ~b)
2424

2525
@acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))
26-
@acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))
27-
@acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))
26+
@acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))
27+
@acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))
2828

2929
@acrule(~x + *(~β, ~x) => *(1 + ~β, ~x))
3030
@acrule(*(~α::isnumber, ~x) + ~x => *(~α + 1, ~x))

src/simplify.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@ of symtype Number.
1111
"""
1212
default_rules(x, ctx) = SIMPLIFY_RULES
1313

14-
function default_rules(x, ctx::EmptyCtx)
14+
function default_rules(x, ctx::DefaultCtx)
1515
has_trig(x) ?
1616
SIMPLIFY_RULES_TRIG :
1717
SIMPLIFY_RULES
1818
end
1919

20+
function default_rules(x, ctx::EmptyCtx)
21+
identity
22+
end
23+
2024
"""
21-
simplify(x, ctx=EmptyCtx();
25+
simplify(x, ctx=DefaultCtx();
2226
rules=default_rules(x, ctx),
2327
fixpoint=true,
2428
applyall=true,
@@ -28,7 +32,7 @@ Simplify an expression by applying `rules` until there are no changes.
2832
The second argument, the context is passed to every [`Contextual`](#Contextual)
2933
predicate and can be accessed as `(@ctx)` in the right hand side of `@rule` expression.
3034
31-
By default the context is an `EmptyCtx()` -- which means there is no contextual information.
35+
By default the context is an `DefaultCtx()` -- which means there is no contextual information.
3236
Any arbitrary type can be used as a context, and packages defining their own contexts
3337
should define `default_rules(ctx::TheContextType)` to return a `RuleSet` that will
3438
be used by default while simplifying under that context.
@@ -39,7 +43,15 @@ Applies them once if `fixpoint=false`.
3943
The `applyall` and `recurse` keywords are forwarded to the enclosed
4044
`RuleSet`, they are mainly used for internal optimization.
4145
"""
42-
function simplify(x, ctx=EmptyCtx(); rules=default_rules(x, ctx), fixpoint=true, applyall=true, kwargs...)
46+
function simplify(x, ctx=DefaultCtx();
47+
rules=default_rules(x, ctx),
48+
fixpoint=true,
49+
applyall=true,
50+
mpoly=false,
51+
kwargs...)
52+
if mpoly
53+
x = to_term(to_mpoly(x)...)
54+
end
4355
if fixpoint
4456
SymbolicUtils.fixpoint(rules, x, ctx; applyall=applyall)
4557
else
@@ -56,8 +68,13 @@ Base.@deprecate simplify(x, rules::RuleSet; kwargs...) simplify(x, rules=rules;
5668
substitute any subexpression that matches a key in `dict` with
5769
the corresponding value.
5870
"""
59-
function substitute(expr, dict)
60-
RuleSet([@rule ~x::(x->haskey(dict, x)) => dict[~x]])(expr) |> fold
71+
function substitute(expr, dict; fold=true)
72+
rs = RuleSet([@rule ~x::(x->haskey(dict, x)) => dict[~x]])
73+
if fold
74+
rs(expr) |> SymbolicUtils.fold
75+
else
76+
rs(expr)
77+
end
6178
end
6279

6380
fold(x) = x

test/basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicUtils: Sym, FnType, Term, symtype, Contextual, EmptyCtx
1+
using SymbolicUtils: Sym, FnType, Term, symtype, Contextual, DefaultCtx
22
using SymbolicUtils
33
using Test
44

@@ -87,7 +87,7 @@ end
8787
@testset "Contexts" begin
8888
@syms a b c
8989

90-
@test @rule(~x::Contextual((x, ctx) -> ctx==EmptyCtx()) => "yes")(1) == "yes"
90+
@test @rule(~x::Contextual((x, ctx) -> ctx==DefaultCtx()) => "yes")(1) == "yes"
9191
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => true)(a, Dict(a=>1))
9292
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => true)(b, Dict(a=>1)) === nothing
9393
@test_throws UndefVarError @rule(~x => __CTX__)(a, "test")

test/fuzzlib.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ function rand_input(T)
77
if T == Bool
88
return rand(Bool)
99
elseif T <: Integer
10-
return rand(-100:100)
10+
x = rand(-100:100)
11+
while iszero(x)
12+
x = rand(-100:100)
13+
end
14+
return x
1115
elseif T == Rational
1216
return Rational(rand_input(Int), rand(1:50)) # no 0 denominator tests yet!
1317
elseif T == Real

test/nf.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using SymbolicUtils, Test
2+
@testset "polyform" begin
3+
@syms a b c d
4+
@test simplify(a * (b + -1 * c) + -1 * (b * a + -1 * c * a), mpoly=true) == 0
5+
@eqtest simplify(sin((a+b)^2)^2; mpoly=true) == sin(a^2+b^2+2*a*b)^2
6+
# fixme: can this be made faster?
7+
@test simplify(sin((a+b)^2)^2 + cos((a+b)^2)^2; mpoly=true) == 1
8+
end

0 commit comments

Comments
 (0)