Skip to content

Commit f987657

Browse files
committed
fix stuff
1 parent 3f6c8ca commit f987657

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

src/abstractalgebra.jl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using DataStructures
22

3+
# Polynomial Normal Form
4+
35
"""
46
labels!(dict, t)
57
@@ -19,11 +21,13 @@ function labels!(dicts, t::Sym)
1921
end
2022

2123
function labels!(dicts, t)
22-
if t isa Term && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
23-
tt = arguments(t)
24-
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), tt))
25-
elseif t isa Integer
24+
if t isa Integer
2625
return t
26+
elseif t isa Term && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
27+
tt = arguments(t)
28+
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
29+
elseif t isa Term && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
30+
return Term{symtype(t)}(operation(t), map(x->labels!(dicts, x), arguments(t)))
2731
else
2832
sym2term, term2sym = dicts
2933
if haskey(term2sym, t)
@@ -33,7 +37,7 @@ function labels!(dicts, t)
3337
tt = arguments(t)
3438
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
3539
sym2term[sym] = Term{symtype(t)}(operation(t),
36-
map(x->to_mpoly(x, dicts)[1], tt))
40+
map(x->to_mpoly(x, dicts)[1], arguments(t)))
3741
else
3842
sym = Sym{symtype(t)}(gensym("literal"))
3943
sym2term[sym] = t
@@ -48,7 +52,13 @@ end
4852
ismpoly(x) = x isa MPoly || x isa Integer
4953
isnonnegint(x) = x isa Integer && x >= 0
5054

51-
function to_mpoly(t, dicts=(OrderedDict(), OrderedDict()))
55+
const mpoly_rules = RuleSet([@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
56+
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
57+
@rule(+(~x) => ~x)
58+
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
59+
@rule(*(~x) => ~x)
60+
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))])
61+
function to_mpoly(t, dicts=(OrderedDict{Sym, Any}(), OrderedDict{Any, Sym}()))
5262
# term2sym is only used to assign the same
5363
# symbol for the same term -- in other words,
5464
# it does common subexpression elimination
@@ -60,20 +70,17 @@ function to_mpoly(t, dicts=(OrderedDict(), OrderedDict()))
6070
return labeled, []
6171
end
6272

63-
ks = collect(keys(sym2term))
73+
ks = sort(collect(keys(sym2term)), lt=<)
6474
R, vars = PolynomialRing(ZZ, String.(nameof.(ks)))
6575

66-
t_poly = substitute(labeled, Dict(ks .=> vars), fold=false)
67-
rs = RuleSet([@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
68-
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
69-
@rule(+(~x) => ~x)
70-
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
71-
@rule(*(~x) => ~x)
72-
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))])
73-
simplify(t_poly, EmptyCtx(), rules=rs), sym2term, Dict(Pair.(1:length(vars), ks))
76+
replace_with_poly = Dict{Sym,MPoly}(zip(ks, vars))
77+
t_poly = substitute(labeled, replace_with_poly, fold=false)
78+
simplify(t_poly, EmptyCtx(), rules=mpoly_rules),
79+
sym2term,
80+
Dict(Pair.(length(vars):-1:1, ks))
7481
end
7582

76-
function to_term(x::MPoly, dict, syms)
83+
function to_term(x, dict, syms)
7784
dict = copy(dict)
7885
for (k, v) in dict
7986
dict[k] = _to_term(v, dict, syms)
@@ -84,7 +91,7 @@ end
8491
function _to_term(x::MPoly, dict, syms)
8592

8693
function mul_coeffs(exps)
87-
monics = [e == 1 ? syms[i] : syms[i]^e for (i, e) in enumerate(exps) if !iszero(e)]
94+
monics = [e == 1 ? syms[i] : syms[i]^e for (i, e) in enumerate(reverse(exps)) if !iszero(e)]
8895
if length(monics) == 1
8996
return monics[1]
9097
elseif length(monics) == 0
@@ -107,7 +114,13 @@ function _to_term(x::MPoly, dict, syms)
107114
substitute(t, dict, fold=false)
108115
end
109116

110-
_to_term(x, dict, vars) = x
117+
function _to_term(x, dict, vars)
118+
if haskey(dict, x)
119+
return dict[x]
120+
else
121+
return x
122+
end
123+
end
111124

112125
function _to_term(x::Term, dict, vars)
113126
t=Term{symtype(x)}(operation(x), _to_term.(arguments(x), (dict,), (vars,)))

0 commit comments

Comments
 (0)