Skip to content

Commit 3f6c8ca

Browse files
shashiYingboMa
andcommitted
treat all other operations or literals as new symbols
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 2dedd35 commit 3f6c8ca

File tree

2 files changed

+59
-34
lines changed

2 files changed

+59
-34
lines changed

src/abstractalgebra.jl

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
using DataStructures
2+
13
"""
2-
labels(dict, t)
4+
labels!(dict, t)
35
46
Find all terms that are not + and * and replace them
57
with a symbol, store the symbol => term mapping in `dict`.
68
"""
7-
function labels end
9+
function labels! end
810

911
# Turn a Term into a multivariate polynomial
10-
labels(dicts, t; label_terms=false) = t
11-
function labels(dicts, t::Sym; label_terms=false)
12+
function labels!(dicts, t::Sym)
1213
sym2term, term2sym = dicts
1314
if !haskey(term2sym, t)
1415
sym2term[t] = t
@@ -17,56 +18,71 @@ function labels(dicts, t::Sym; label_terms=false)
1718
return t
1819
end
1920

20-
function labels(dicts, t::Term; label_terms=false)
21-
tt = arguments(t)
22-
if operation(t) == (*) || operation(t) == (+)
23-
return Term{symtype(t)}(operation(t), map(x->labels(dicts, x;
24-
label_terms=label_terms), tt))
21+
function labels!(dicts, t)
22+
if t isa Term && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
23+
tt = arguments(t)
24+
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), tt))
25+
elseif t isa Integer
26+
return t
2527
else
2628
sym2term, term2sym = dicts
2729
if haskey(term2sym, t)
2830
return term2sym[t]
2931
end
30-
if label_terms
32+
if t isa Term
33+
tt = arguments(t)
3134
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
32-
sym2term[sym] = Term{symtype(t)}(operation(t), map(x->labels(dicts, x;
33-
label_terms=label_terms),
34-
tt))
35-
x = term2sym[t] = sym
36-
37-
return x
35+
sym2term[sym] = Term{symtype(t)}(operation(t),
36+
map(x->to_mpoly(x, dicts)[1], tt))
3837
else
39-
return Term{symtype(t)}(operation(t), map(x->labels(dicts, x; label_terms=label_terms), tt))
38+
sym = Sym{symtype(t)}(gensym("literal"))
39+
sym2term[sym] = t
4040
end
41+
42+
term2sym[t] = sym
43+
44+
return sym
4145
end
4246
end
4347

4448
ismpoly(x) = x isa MPoly || x isa Integer
4549
isnonnegint(x) = x isa Integer && x >= 0
4650

47-
function to_mpoly(t)
48-
sym2term, term2sym = Dict(), Dict()
49-
ls = labels((sym2term, term2sym), t)
51+
function to_mpoly(t, dicts=(OrderedDict(), OrderedDict()))
52+
# term2sym is only used to assign the same
53+
# symbol for the same term -- in other words,
54+
# it does common subexpression elimination
55+
56+
sym2term, term2sym = dicts
57+
labeled = labels!((sym2term, term2sym), t)
5058

5159
if isempty(sym2term)
52-
return t, []
60+
return labeled, []
5361
end
5462

5563
ks = collect(keys(sym2term))
5664
R, vars = PolynomialRing(ZZ, String.(nameof.(ks)))
5765

58-
t_poly_1 = substitute(t, term2sym, fold=false)
59-
t_poly_2 = substitute(t_poly_1, Dict(ks .=> vars), fold=false)
66+
t_poly = substitute(labeled, Dict(ks .=> vars), fold=false)
6067
rs = RuleSet([@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
6168
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
6269
@rule(+(~x) => ~x)
6370
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
6471
@rule(*(~x) => ~x)
6572
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))])
66-
simplify(t_poly_2, EmptyCtx(), rules=rs), Dict(Pair.(1:length(vars), ks))
73+
simplify(t_poly, EmptyCtx(), rules=rs), sym2term, Dict(Pair.(1:length(vars), ks))
6774
end
6875

69-
function to_term(x::MPoly, syms)
76+
function to_term(x::MPoly, dict, syms)
77+
dict = copy(dict)
78+
for (k, v) in dict
79+
dict[k] = _to_term(v, dict, syms)
80+
end
81+
_to_term(x, dict, syms)
82+
end
83+
84+
function _to_term(x::MPoly, dict, syms)
85+
7086
function mul_coeffs(exps)
7187
monics = [e == 1 ? syms[i] : syms[i]^e for (i, e) in enumerate(exps) if !iszero(e)]
7288
if length(monics) == 1
@@ -77,20 +93,24 @@ function to_term(x::MPoly, syms)
7793
return Term(*, monics)
7894
end
7995
end
96+
8097
monoms = [mul_coeffs(exponent_vector(x, i)) for i in 1:x.length]
81-
if length(monoms) == 1
82-
!isone(x.coeffs[1]) ? monoms[1] * x.coeffs[1] : monoms[1]
83-
elseif length(monoms) == 0
98+
if length(monoms) == 0
8499
return 0
100+
end
101+
if length(monoms) == 1
102+
t = !isone(x.coeffs[1]) ? monoms[1] * x.coeffs[1] : monoms[1]
85103
else
86-
Term(+, map((x,y)->isone(y) ? x : y*x, monoms, x.coeffs[1:length(monoms)]))
104+
t = Term(+, map((x,y)->isone(y) ? x : y*x, monoms, x.coeffs[1:length(monoms)]))
87105
end
106+
107+
substitute(t, dict, fold=false)
88108
end
89109

90-
to_term(x, vars) = x
110+
_to_term(x, dict, vars) = x
91111

92-
function to_term(x::Term, vars)
93-
Term{symtype(x)}(operation(x), to_term.(arguments(x), (vars,)))
112+
function _to_term(x::Term, dict, vars)
113+
t=Term{symtype(x)}(operation(x), _to_term.(arguments(x), (dict,), (vars,)))
94114
end
95115

96116
<(a::MPoly, b::MPoly) = false

src/simplify.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,13 @@ Applies them once if `fixpoint=false`.
4343
The `applyall` and `recurse` keywords are forwarded to the enclosed
4444
`RuleSet`, they are mainly used for internal optimization.
4545
"""
46-
function simplify(x, ctx=DefaultCtx(); rules=default_rules(x, ctx), fixpoint=true, applyall=true, kwargs...)
47-
if ctx isa DefaultCtx
46+
function simplify(x, ctx=DefaultCtx();
47+
rules=default_rules(x, ctx),
48+
fixpoint=true,
49+
applyall=true,
50+
mpoly=false,
51+
kwargs...)
52+
if mpoly
4853
x = to_term(to_mpoly(x)...)
4954
end
5055
if fixpoint

0 commit comments

Comments
 (0)