Skip to content

Commit 1b31a67

Browse files
committed
Fix mpoly.deflation, fix invariant tests
1 parent 2e0ebc7 commit 1b31a67

File tree

6 files changed

+134
-86
lines changed

6 files changed

+134
-86
lines changed

src/flint/test/test_all.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3241,21 +3241,26 @@ def quick_poly():
32413241

32423242
f1 = 3*x0**2*x1**2 + 6*x0*x1**2 + 9*x1**2
32433243
res, stride = f1.deflation()
3244-
assert res == 3*x0**2 + 6*x0 + 9
3245-
assert tuple(stride) == (1, 0)
3244+
assert res == 3*x0**2*x1 + 6*x0*x1 + 9*x1
3245+
assert tuple(stride) == (1, 2)
32463246

32473247
g1 = ((x0**2 + x1**2)**3 + (x0**2 + x1**2)**2 + 1)
32483248
res, stride = g1.deflation()
32493249
assert res == x0**3 + 3*x0**2*x1 + x0**2 + 3*x0*x1**2 + 2*x0*x1 + x1**3 + x1**2 + 1
32503250
assert tuple(stride) == (2, 2)
32513251

32523252
for p in [f1, g1]:
3253-
n, m = p.deflation_monom()
3254-
assert m * p.deflate(n).inflate(n) == p
3253+
pd, n = p.deflation()
3254+
assert pd.inflate(n) == p
3255+
assert p.deflate(n).inflate(n) == p
32553256

3256-
n, i = p.deflation_index()
3257-
m = ctx.term(exp_vec=i)
3258-
assert p.deflate(n).inflate(n) * m == p
3257+
pd, n, m = p.deflation_monom()
3258+
assert m * pd.inflate(n) == p
3259+
3260+
if not composite_characteristic:
3261+
n, i = p.deflation_index()
3262+
m = ctx.term(exp_vec=i)
3263+
assert (p / m).deflate(n).inflate(n) * m == p
32593264

32603265
if P is flint.fmpz_mpoly:
32613266
assert (x0**2 * x1 + x0 * x1).primitive() == (1, x0**2*x1 + x0*x1)

src/flint/types/fmpq_mpoly.pyx

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,25 +1058,31 @@ cdef class fmpq_mpoly(flint_mpoly):
10581058
>>> from flint import Ordering
10591059
>>> ctx = fmpq_mpoly_ctx.get_context(2, Ordering.lex, nametup=('x', 'y'))
10601060
>>> x, y = ctx.gens()
1061-
>>> f = x**3 * y + x * y**4 + x * y
1061+
>>> f = x**2 * y**2 + x * y**2
10621062
>>> q, N = f.deflation()
10631063
>>> q, N
1064-
(x + y + 1, [2, 3])
1064+
(x^2*y + x*y, [1, 2])
1065+
>>> q.inflate(N) == f
1066+
True
10651067
"""
10661068
cdef:
1067-
fmpz_vec _shift = fmpz_vec(self.ctx.nvars())
1068-
fmpz_vec stride = fmpz_vec(self.ctx.nvars())
1069+
slong nvars = self.ctx.nvars()
1070+
fmpz_vec shift = fmpz_vec(nvars)
1071+
fmpz_vec stride = fmpz_vec(nvars)
10691072
fmpq_mpoly res = create_fmpq_mpoly(self.ctx)
10701073

1071-
fmpz_mpoly_deflation(_shift.val, stride.val, self.val.zpoly, self.ctx.val.zctx)
1074+
fmpz_mpoly_deflation(shift.val, stride.val, self.val.zpoly, self.ctx.val.zctx)
1075+
1076+
for i in range(nvars):
1077+
stride[i] = shift[i].gcd(stride[i])
1078+
shift[i] = 0
10721079

1073-
cdef fmpz_vec zero_shift = fmpz_vec(self.ctx.nvars())
1074-
fmpz_mpoly_deflate(res.val.zpoly, self.val.zpoly, zero_shift.val, stride.val, self.ctx.val.zctx)
1080+
fmpz_mpoly_deflate(res.val.zpoly, self.val.zpoly, shift.val, stride.val, self.ctx.val.zctx)
10751081
fmpq_set(res.val.content, self.val.content)
10761082

10771083
return res, list(stride)
10781084

1079-
def deflation_monom(self) -> tuple[list[int], fmpq_mpoly]:
1085+
def deflation_monom(self) -> tuple[fmpq_mpoly, list[int], fmpq_mpoly]:
10801086
"""
10811087
Compute the exponent vector ``N`` and monomial ``m`` such that ``p(X^(1/N))
10821088
= m * q(X^N)`` for maximal N. Importantly the deflation itself is not computed
@@ -1086,21 +1092,25 @@ cdef class fmpq_mpoly(flint_mpoly):
10861092
>>> ctx = fmpq_mpoly_ctx.get_context(2, Ordering.lex, nametup=('x', 'y'))
10871093
>>> x, y = ctx.gens()
10881094
>>> f = x**3 * y + x * y**4 + x * y
1089-
>>> N, m = f.deflation_monom()
1090-
>>> N, m
1091-
([2, 3], x*y)
1092-
>>> f_deflated = f.deflate(N)
1093-
>>> f_deflated
1094-
x + y + 1
1095-
>>> m * f_deflated.inflate(N)
1095+
>>> fd, N, m = f.deflation_monom()
1096+
>>> fd, N, m
1097+
(x + y + 1, [2, 3], x*y)
1098+
>>> m * fd.inflate(N)
10961099
x^3*y + x*y^4 + x*y
10971100
"""
1098-
cdef fmpq_mpoly monom = create_fmpq_mpoly(self.ctx)
1101+
cdef:
1102+
slong nvars = self.ctx.nvars()
1103+
fmpq_mpoly res = create_fmpq_mpoly(self.ctx)
1104+
fmpq_mpoly monom = create_fmpq_mpoly(self.ctx)
1105+
fmpz_vec shift = fmpz_vec(nvars)
1106+
fmpz_vec stride = fmpz_vec(nvars)
10991107

1100-
stride, _shift = self.deflation_index()
1108+
fmpz_mpoly_deflation(shift.val, stride.val, self.val.zpoly, self.ctx.val.zctx)
1109+
fmpq_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(shift).val, self.ctx.val)
1110+
fmpz_mpoly_deflate(res.val.zpoly, self.val.zpoly, shift.val, stride.val, self.ctx.val.zctx)
1111+
fmpq_set(res.val.content, self.val.content)
11011112

1102-
fmpq_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(_shift).val, self.ctx.val)
1103-
return list(stride), monom
1113+
return res, list(stride), monom
11041114

11051115
def deflation_index(self) -> tuple[list[int], list[int]]:
11061116
"""

src/flint/types/fmpz_mod_mpoly.pyx

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,24 +1137,30 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
11371137
>>> from flint import Ordering
11381138
>>> ctx = fmpz_mod_mpoly_ctx.get_context(2, Ordering.lex, 11, nametup=('x', 'y'))
11391139
>>> x, y = ctx.gens()
1140-
>>> f = x**3 * y + x * y**4 + x * y
1140+
>>> f = x**2 * y**2 + x * y**2
11411141
>>> q, N = f.deflation()
11421142
>>> q, N
1143-
(x + y + 1, [2, 3])
1143+
(x^2*y + x*y, [1, 2])
1144+
>>> q.inflate(N) == f
1145+
True
11441146
"""
11451147
cdef:
1146-
fmpz_vec _shift = fmpz_vec(self.ctx.nvars())
1147-
fmpz_vec stride = fmpz_vec(self.ctx.nvars())
1148+
slong nvars = self.ctx.nvars()
1149+
fmpz_vec shift = fmpz_vec(nvars)
1150+
fmpz_vec stride = fmpz_vec(nvars)
11481151
fmpz_mod_mpoly res = create_fmpz_mod_mpoly(self.ctx)
11491152

1150-
fmpz_mod_mpoly_deflation(_shift.val, stride.val, self.val, self.ctx.val)
1153+
fmpz_mod_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1154+
1155+
for i in range(nvars):
1156+
stride[i] = shift[i].gcd(stride[i])
1157+
shift[i] = 0
11511158

1152-
cdef fmpz_vec zero_shift = fmpz_vec(self.ctx.nvars())
1153-
fmpz_mod_mpoly_deflate(res.val, self.val, zero_shift.val, stride.val, self.ctx.val)
1159+
fmpz_mod_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11541160

11551161
return res, list(stride)
11561162

1157-
def deflation_monom(self) -> tuple[list[int], fmpz_mod_mpoly]:
1163+
def deflation_monom(self) -> tuple[fmpz_mod_mpoly, list[int], fmpz_mod_mpoly]:
11581164
"""
11591165
Compute the exponent vector ``N`` and monomial ``m`` such that ``p(X^(1/N))
11601166
= m * q(X^N)`` for maximal N. Importantly the deflation itself is not computed
@@ -1164,21 +1170,24 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
11641170
>>> ctx = fmpz_mod_mpoly_ctx.get_context(2, Ordering.lex, 11, nametup=('x', 'y'))
11651171
>>> x, y = ctx.gens()
11661172
>>> f = x**3 * y + x * y**4 + x * y
1167-
>>> N, m = f.deflation_monom()
1168-
>>> N, m
1169-
([2, 3], x*y)
1170-
>>> f_deflated = f.deflate(N)
1171-
>>> f_deflated
1172-
x + y + 1
1173-
>>> m * f_deflated.inflate(N)
1173+
>>> fd, N, m = f.deflation_monom()
1174+
>>> fd, N, m
1175+
(x + y + 1, [2, 3], x*y)
1176+
>>> m * fd.inflate(N)
11741177
x^3*y + x*y^4 + x*y
11751178
"""
1176-
cdef fmpz_mod_mpoly monom = create_fmpz_mod_mpoly(self.ctx)
1179+
cdef:
1180+
slong nvars = self.ctx.nvars()
1181+
fmpz_mod_mpoly res = create_fmpz_mod_mpoly(self.ctx)
1182+
fmpz_mod_mpoly monom = create_fmpz_mod_mpoly(self.ctx)
1183+
fmpz_vec shift = fmpz_vec(nvars)
1184+
fmpz_vec stride = fmpz_vec(nvars)
11771185

1178-
stride, _shift = self.deflation_index()
1186+
fmpz_mod_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1187+
fmpz_mod_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(shift).val, self.ctx.val)
1188+
fmpz_mod_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11791189

1180-
fmpz_mod_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(_shift).val, self.ctx.val)
1181-
return list(stride), monom
1190+
return res, list(stride), monom
11821191

11831192
def deflation_index(self) -> tuple[list[int], list[int]]:
11841193
"""

src/flint/types/fmpz_mpoly.pyx

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,24 +1156,30 @@ cdef class fmpz_mpoly(flint_mpoly):
11561156
>>> from flint import Ordering
11571157
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, nametup=('x', 'y'))
11581158
>>> x, y = ctx.gens()
1159-
>>> f = x**3 * y + x * y**4 + x * y
1159+
>>> f = x**2 * y**2 + x * y**2
11601160
>>> q, N = f.deflation()
11611161
>>> q, N
1162-
(x + y + 1, [2, 3])
1162+
(x^2*y + x*y, [1, 2])
1163+
>>> q.inflate(N) == f
1164+
True
11631165
"""
11641166
cdef:
1165-
fmpz_vec _shift = fmpz_vec(self.ctx.nvars())
1166-
fmpz_vec stride = fmpz_vec(self.ctx.nvars())
1167+
slong nvars = self.ctx.nvars()
1168+
fmpz_vec shift = fmpz_vec(nvars)
1169+
fmpz_vec stride = fmpz_vec(nvars)
11671170
fmpz_mpoly res = create_fmpz_mpoly(self.ctx)
11681171

1169-
fmpz_mpoly_deflation(_shift.val, stride.val, self.val, self.ctx.val)
1172+
fmpz_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1173+
1174+
for i in range(nvars):
1175+
stride[i] = shift[i].gcd(stride[i])
1176+
shift[i] = 0
11701177

1171-
cdef fmpz_vec zero_shift = fmpz_vec(self.ctx.nvars())
1172-
fmpz_mpoly_deflate(res.val, self.val, zero_shift.val, stride.val, self.ctx.val)
1178+
fmpz_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11731179

11741180
return res, list(stride)
11751181

1176-
def deflation_monom(self) -> tuple[list[int], fmpz_mpoly]:
1182+
def deflation_monom(self) -> tuple[fmpz_mpoly, list[int], fmpz_mpoly]:
11771183
"""
11781184
Compute the exponent vector ``N`` and monomial ``m`` such that ``p(X^(1/N))
11791185
= m * q(X^N)`` for maximal N. Importantly the deflation itself is not computed
@@ -1183,21 +1189,24 @@ cdef class fmpz_mpoly(flint_mpoly):
11831189
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, nametup=('x', 'y'))
11841190
>>> x, y = ctx.gens()
11851191
>>> f = x**3 * y + x * y**4 + x * y
1186-
>>> N, m = f.deflation_monom()
1187-
>>> N, m
1188-
([2, 3], x*y)
1189-
>>> f_deflated = f.deflate(N)
1190-
>>> f_deflated
1191-
x + y + 1
1192-
>>> m * f_deflated.inflate(N)
1192+
>>> fd, N, m = f.deflation_monom()
1193+
>>> fd, N, m
1194+
(x + y + 1, [2, 3], x*y)
1195+
>>> m * fd.inflate(N)
11931196
x^3*y + x*y^4 + x*y
11941197
"""
1195-
cdef fmpz_mpoly monom = create_fmpz_mpoly(self.ctx)
1198+
cdef:
1199+
slong nvars = self.ctx.nvars()
1200+
fmpz_mpoly res = create_fmpz_mpoly(self.ctx)
1201+
fmpz_mpoly monom = create_fmpz_mpoly(self.ctx)
1202+
fmpz_vec shift = fmpz_vec(nvars)
1203+
fmpz_vec stride = fmpz_vec(nvars)
11961204

1197-
stride, _shift = self.deflation_index()
1205+
fmpz_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1206+
fmpz_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(shift).val, self.ctx.val)
1207+
fmpz_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11981208

1199-
fmpz_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(_shift).val, self.ctx.val)
1200-
return list(stride), monom
1209+
return res, list(stride), monom
12011210

12021211
def deflation_index(self) -> tuple[list[int], list[int]]:
12031212
"""

src/flint/types/fmpz_poly.pyx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -630,8 +630,11 @@ cdef class fmpz_poly(flint_poly):
630630
n. returns ``q, n`` such that ``self == q.inflate(n)``.
631631

632632
>>> f = fmpz_poly([1, 0, 1])
633-
>>> f.deflation()
633+
>>> q, n = f.deflation()
634+
>>> q, n
634635
(x + 1, 2)
636+
>>> q.inflate(n) == f
637+
True
635638
"""
636639
cdef ulong n
637640
if fmpz_poly_is_zero(self.val):
@@ -647,14 +650,17 @@ cdef class fmpz_poly(flint_poly):
647650

648651
>>> f = fmpz_poly([1, 0, 1])
649652
>>> f.deflation_monom()
650-
(1, x)
653+
(x^2 + 1, 1, x)
651654
"""
652655
n, m = self.deflation_index()
653656

654657
cdef fmpz_poly monom = fmpz_poly.__new__(fmpz_poly)
658+
cdef fmpz_poly res = fmpz_poly.__new__(fmpz_poly)
659+
655660
fmpz_poly_set_coeff_ui(monom.val, m, 1)
661+
fmpz_poly_deflate(res.val, self.val, n)
656662

657-
return n, monom
663+
return res, n, monom
658664

659665
def deflation_index(self) -> tuple[int, int]:
660666
"""

src/flint/types/nmod_mpoly.pyx

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,24 +1107,30 @@ cdef class nmod_mpoly(flint_mpoly):
11071107
>>> from flint import Ordering
11081108
>>> ctx = nmod_mpoly_ctx.get_context(2, Ordering.lex, 11, nametup=('x', 'y'))
11091109
>>> x, y = ctx.gens()
1110-
>>> f = x**3 * y + x * y**4 + x * y
1110+
>>> f = x**2 * y**2 + x * y**2
11111111
>>> q, N = f.deflation()
11121112
>>> q, N
1113-
(x + y + 1, [2, 3])
1113+
(x^2*y + x*y, [1, 2])
1114+
>>> q.inflate(N) == f
1115+
True
11141116
"""
11151117
cdef:
1116-
fmpz_vec _shift = fmpz_vec(self.ctx.nvars())
1117-
fmpz_vec stride = fmpz_vec(self.ctx.nvars())
1118+
slong nvars = self.ctx.nvars()
1119+
fmpz_vec shift = fmpz_vec(nvars)
1120+
fmpz_vec stride = fmpz_vec(nvars)
11181121
nmod_mpoly res = create_nmod_mpoly(self.ctx)
11191122

1120-
nmod_mpoly_deflation(_shift.val, stride.val, self.val, self.ctx.val)
1123+
nmod_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1124+
1125+
for i in range(nvars):
1126+
stride[i] = shift[i].gcd(stride[i])
1127+
shift[i] = 0
11211128

1122-
cdef fmpz_vec zero_shift = fmpz_vec(self.ctx.nvars())
1123-
nmod_mpoly_deflate(res.val, self.val, zero_shift.val, stride.val, self.ctx.val)
1129+
nmod_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11241130

11251131
return res, list(stride)
11261132

1127-
def deflation_monom(self) -> tuple[list[int], nmod_mpoly]:
1133+
def deflation_monom(self) -> tuple[nmod_mpoly, list[int], nmod_mpoly]:
11281134
"""
11291135
Compute the exponent vector ``N`` and monomial ``m`` such that ``p(X^(1/N))
11301136
= m * q(X^N)`` for maximal N. Importantly the deflation itself is not computed
@@ -1134,21 +1140,24 @@ cdef class nmod_mpoly(flint_mpoly):
11341140
>>> ctx = nmod_mpoly_ctx.get_context(2, Ordering.lex, 11, nametup=('x', 'y'))
11351141
>>> x, y = ctx.gens()
11361142
>>> f = x**3 * y + x * y**4 + x * y
1137-
>>> N, m = f.deflation_monom()
1138-
>>> N, m
1139-
([2, 3], x*y)
1140-
>>> f_deflated = f.deflate(N)
1141-
>>> f_deflated
1142-
x + y + 1
1143-
>>> m * f_deflated.inflate(N)
1143+
>>> fd, N, m = f.deflation_monom()
1144+
>>> fd, N, m
1145+
(x + y + 1, [2, 3], x*y)
1146+
>>> m * fd.inflate(N)
11441147
x^3*y + x*y^4 + x*y
11451148
"""
1146-
cdef nmod_mpoly monom = create_nmod_mpoly(self.ctx)
1149+
cdef:
1150+
slong nvars = self.ctx.nvars()
1151+
nmod_mpoly res = create_nmod_mpoly(self.ctx)
1152+
nmod_mpoly monom = create_nmod_mpoly(self.ctx)
1153+
fmpz_vec shift = fmpz_vec(nvars)
1154+
fmpz_vec stride = fmpz_vec(nvars)
11471155

1148-
stride, _shift = self.deflation_index()
1156+
nmod_mpoly_deflation(shift.val, stride.val, self.val, self.ctx.val)
1157+
nmod_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(shift).val, self.ctx.val)
1158+
nmod_mpoly_deflate(res.val, self.val, shift.val, stride.val, self.ctx.val)
11491159

1150-
nmod_mpoly_push_term_ui_ffmpz(monom.val, 1, fmpz_vec(_shift).val, self.ctx.val)
1151-
return list(stride), monom
1160+
return res, list(stride), monom
11521161

11531162
def deflation_index(self) -> tuple[list[int], list[int]]:
11541163
"""

0 commit comments

Comments
 (0)