1+ using DataStructures
2+
13"""
2- labels(dict, t)
4+ labels! (dict, t)
35
46Find all terms that are not + and * and replace them
57with a symbol, store the symbol => term mapping in `dict`.
68"""
7- function labels end
9+ function labels! end
810
911# Turn a Term into a multivariate polynomial
10- labels (dicts, t; label_terms= false ) = t
11- function labels (dicts, t:: Sym ; label_terms= false )
12+ function labels! (dicts, t:: Sym )
1213 sym2term, term2sym = dicts
1314 if ! haskey (term2sym, t)
1415 sym2term[t] = t
@@ -17,56 +18,71 @@ function labels(dicts, t::Sym; label_terms=false)
1718 return t
1819end
1920
20- function labels (dicts, t:: Term ; label_terms= false )
21- tt = arguments (t)
22- if operation (t) == (* ) || operation (t) == (+ )
23- return Term {symtype(t)} (operation (t), map (x-> labels (dicts, x;
24- label_terms= label_terms), tt))
21+ function labels! (dicts, t)
22+ if t isa Term && (operation (t) == (* ) || operation (t) == (+ ) || operation (t) == (- ))
23+ tt = arguments (t)
24+ return Term {symtype(t)} (operation (t), map (x-> labels! (dicts, x), tt))
25+ elseif t isa Integer
26+ return t
2527 else
2628 sym2term, term2sym = dicts
2729 if haskey (term2sym, t)
2830 return term2sym[t]
2931 end
30- if label_terms
32+ if t isa Term
33+ tt = arguments (t)
3134 sym = Sym {symtype(t)} (gensym (nameof (operation (t))))
32- sym2term[sym] = Term {symtype(t)} (operation (t), map (x-> labels (dicts, x;
33- label_terms= label_terms),
34- tt))
35- x = term2sym[t] = sym
36-
37- return x
35+ sym2term[sym] = Term {symtype(t)} (operation (t),
36+ map (x-> to_mpoly (x, dicts)[1 ], tt))
3837 else
39- return Term {symtype(t)} (operation (t), map (x-> labels (dicts, x; label_terms= label_terms), tt))
38+ sym = Sym {symtype(t)} (gensym (" literal" ))
39+ sym2term[sym] = t
4040 end
41+
42+ term2sym[t] = sym
43+
44+ return sym
4145 end
4246end
4347
4448ismpoly (x) = x isa MPoly || x isa Integer
4549isnonnegint (x) = x isa Integer && x >= 0
4650
47- function to_mpoly (t)
48- sym2term, term2sym = Dict (), Dict ()
49- ls = labels ((sym2term, term2sym), t)
51+ function to_mpoly (t, dicts= (OrderedDict (), OrderedDict ()))
52+ # term2sym is only used to assign the same
53+ # symbol for the same term -- in other words,
54+ # it does common subexpression elimination
55+
56+ sym2term, term2sym = dicts
57+ labeled = labels! ((sym2term, term2sym), t)
5058
5159 if isempty (sym2term)
52- return t , []
60+ return labeled , []
5361 end
5462
5563 ks = collect (keys (sym2term))
5664 R, vars = PolynomialRing (ZZ, String .(nameof .(ks)))
5765
58- t_poly_1 = substitute (t, term2sym, fold= false )
59- t_poly_2 = substitute (t_poly_1, Dict (ks .=> vars), fold= false )
66+ t_poly = substitute (labeled, Dict (ks .=> vars), fold= false )
6067 rs = RuleSet ([@rule (~ x:: ismpoly - ~ y:: ismpoly => ~ x + - 1 * (~ y))
6168 @acrule (~ x:: ismpoly + ~ y:: ismpoly => ~ x + ~ y)
6269 @rule (+ (~ x) => ~ x)
6370 @acrule (~ x:: ismpoly * ~ y:: ismpoly => ~ x * ~ y)
6471 @rule (* (~ x) => ~ x)
6572 @rule ((~ x:: ismpoly )^ (~ a:: isnonnegint ) => (~ x)^ (~ a))])
66- simplify (t_poly_2 , EmptyCtx (), rules= rs), Dict (Pair .(1 : length (vars), ks))
73+ simplify (t_poly , EmptyCtx (), rules= rs), sym2term , Dict (Pair .(1 : length (vars), ks))
6774end
6875
69- function to_term (x:: MPoly , syms)
76+ function to_term (x:: MPoly , dict, syms)
77+ dict = copy (dict)
78+ for (k, v) in dict
79+ dict[k] = _to_term (v, dict, syms)
80+ end
81+ _to_term (x, dict, syms)
82+ end
83+
84+ function _to_term (x:: MPoly , dict, syms)
85+
7086 function mul_coeffs (exps)
7187 monics = [e == 1 ? syms[i] : syms[i]^ e for (i, e) in enumerate (exps) if ! iszero (e)]
7288 if length (monics) == 1
@@ -77,20 +93,24 @@ function to_term(x::MPoly, syms)
7793 return Term (* , monics)
7894 end
7995 end
96+
8097 monoms = [mul_coeffs (exponent_vector (x, i)) for i in 1 : x. length]
81- if length (monoms) == 1
82- ! isone (x. coeffs[1 ]) ? monoms[1 ] * x. coeffs[1 ] : monoms[1 ]
83- elseif length (monoms) == 0
98+ if length (monoms) == 0
8499 return 0
100+ end
101+ if length (monoms) == 1
102+ t = ! isone (x. coeffs[1 ]) ? monoms[1 ] * x. coeffs[1 ] : monoms[1 ]
85103 else
86- Term (+ , map ((x,y)-> isone (y) ? x : y* x, monoms, x. coeffs[1 : length (monoms)]))
104+ t = Term (+ , map ((x,y)-> isone (y) ? x : y* x, monoms, x. coeffs[1 : length (monoms)]))
87105 end
106+
107+ substitute (t, dict, fold= false )
88108end
89109
90- to_term (x , vars) = x
110+ _to_term (x, dict , vars) = x
91111
92- function to_term (x:: Term , vars)
93- Term {symtype(x)} (operation (x), to_term .(arguments (x), (vars,)))
112+ function _to_term (x:: Term , dict , vars)
113+ t = Term {symtype(x)} (operation (x), _to_term .(arguments (x), (dict, ), (vars,)))
94114end
95115
96116< ₑ (a:: MPoly , b:: MPoly ) = false
0 commit comments