Skip to content

Commit 26f85d6

Browse files
committed
Add tests for sqrt/gcd/xgcd
1 parent dcbe363 commit 26f85d6

File tree

10 files changed

+75
-16
lines changed

10 files changed

+75
-16
lines changed

src/flint/test/test_all.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def test_fmpz_poly():
454454
assert Z([1,2,-4]).height_bits() == 3
455455
assert Z([1,2,-4]).height_bits(signed=True) == -3
456456
assert Z([1,2,1]).sqrt() == Z([1,1])
457-
assert raises(lambda: Z([1,2,2]).sqrt(), ValueError)
457+
assert raises(lambda: Z([1,2,2]).sqrt(), DomainError)
458458
assert Z([1,0,2,0,3]).deflation() == (Z([1,2,3]), 2)
459459
assert Z([]).deflation() == (Z([]), 1)
460460
assert Z([1,1]).deflation() == (Z([1,1]), 1)
@@ -2204,7 +2204,7 @@ def test_fmpz_mod_poly():
22042204

22052205
# sqrt
22062206
f1 = R_test.random_element(irreducible=True)
2207-
assert raises(lambda: f1.sqrt(), ValueError)
2207+
assert raises(lambda: f1.sqrt(), DomainError)
22082208
assert (f1*f1).sqrt() in [f1, -f1]
22092209

22102210
# deflation
@@ -2777,7 +2777,7 @@ def setbad(obj, i, val):
27772777

27782778
if not composite_characteristic:
27792779
assert P([1, 2, 1]).sqrt() == P([1, 1])
2780-
assert raises(lambda: P([1, 2, 2]).sqrt(), ValueError), f"{P}, {P([1, 2, 2]).sqrt()}"
2780+
assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError)
27812781
elif nmod_poly_will_crash:
27822782
pass
27832783
else:
@@ -3231,7 +3231,7 @@ def quick_poly():
32313231
assert (f * f).sqrt() == f
32323232
if P is flint.fmpz_mpoly:
32333233
assert (f * f).sqrt(assume_perfect_square=True) == f
3234-
assert raises(lambda: quick_poly().sqrt(), ValueError)
3234+
assert raises(lambda: quick_poly().sqrt(), DomainError)
32353235

32363236
p = quick_poly()
32373237
assert p.derivative(0) == p.derivative("x0") == mpoly({(0, 0): 3, (1, 2): 8})
@@ -3364,6 +3364,35 @@ def factor_sqf(p):
33643364
# All tests below would raise
33653365
continue
33663366

3367+
assert S(0).sqrt() == S(0)
3368+
assert S(1).sqrt() == S(1)
3369+
assert S(4).sqrt()**2 == S(4)
3370+
for i in range(100):
3371+
try:
3372+
assert S(i).sqrt() ** 2 == S(i)
3373+
except DomainError:
3374+
pass
3375+
3376+
if characteristic == 0:
3377+
assert raises(lambda: S(-1).sqrt(), DomainError)
3378+
else:
3379+
try:
3380+
assert S(-1).sqrt() ** 2 == S(-1)
3381+
except DomainError:
3382+
pass
3383+
3384+
assert (0*x).sqrt() == 0*x
3385+
assert (1*x/x).sqrt() == 0*x + 1
3386+
assert (4*x/x).sqrt()**2 == 0*x + 4
3387+
for i in range(100):
3388+
try:
3389+
assert (i*x).sqrt() ** 2 == i*x
3390+
except DomainError:
3391+
pass
3392+
assert (x**2).sqrt() == x
3393+
assert (S(4)*x**2).sqrt()**2 == S(4)*x**2
3394+
assert raises(lambda: (x**2 + 1).sqrt(), DomainError)
3395+
33673396
assert factor(0*x) == (S(0), [])
33683397
assert factor(0*x + 1) == (S(1), [])
33693398
assert factor(0*x + 3) == (S(3), [])
@@ -3379,7 +3408,22 @@ def factor_sqf(p):
33793408
assert factor_sqf(x**2) == (S(1), [(x, 2)])
33803409
assert factor_sqf(2*(x+1)) == (S(2), [(x+1, 1)])
33813410

3382-
assert factor(x*(x+1)) == (S(1), [(x, 1), (x+1, 1)])
3411+
assert (2*x).gcd(x) == x
3412+
assert (2*x).gcd(x**2) == x
3413+
assert (2*x).gcd(x**2 + 1) == S(1)
3414+
3415+
if not is_field:
3416+
# primitive gcd over Z
3417+
assert (2*x).gcd(4*x**2) == 2*x
3418+
else:
3419+
# monic gcd over Q, Z/pZ and GF(p^d)
3420+
assert (2*x).gcd(4*x**2) == x
3421+
3422+
if is_field and y is None:
3423+
# xgcd is defined and consistent for all univariate polynomials
3424+
# over a field (Q, Z/pZ, GF(p^d)).
3425+
assert (2*x).xgcd(4*x) == (x, P(0), P(1)/4)
3426+
assert (2*x).xgcd(4*x**2+1) == (P(1), -2*x, P(1))
33833427

33843428
# mpoly types have a slightly different squarefree factorisation
33853429
# because they handle trivial factors differently. It looks like a
@@ -3388,6 +3432,7 @@ def factor_sqf(p):
33883432
#
33893433
# Maybe it is worth making them consistent by absorbing the power
33903434
# of x into a factor of equal multiplicity.
3435+
assert factor(x*(x+1)) == (S(1), [(x, 1), (x+1, 1)])
33913436
if y is None:
33923437
# *_poly types
33933438
assert factor_sqf(x*(x+1)) == (S(1), [(x**2+x, 1)])
@@ -3399,7 +3444,7 @@ def factor_sqf(p):
33993444
# a unique multiplicity.
34003445
assert factor_sqf(x**2*(x+1)) == (S(1), [(x+1, 1), (x, 2)])
34013446

3402-
# This is the same for all types because there is no tivial monomial
3447+
# This is the same for all types because there is no trivial monomial
34033448
# factor to extract.
34043449
assert factor((x-1)*(x+1)) == (S(1), sort([(x-1, 1), (x+1, 1)]))
34053450
assert factor_sqf((x-1)*(x+1)) == (S(1), [(x**2-1, 1)])
@@ -3433,6 +3478,8 @@ def factor_sqf(p):
34333478

34343479
if y is not None:
34353480

3481+
# *_mpoly types
3482+
34363483
assert factor(x*y+1) == (S(1), [(x*y+1, 1)])
34373484
assert factor(x*y) == (S(1), [(x, 1), (y, 1)])
34383485

@@ -3451,6 +3498,13 @@ def factor_sqf(p):
34513498
else:
34523499
assert factor(p) == factor_sqf(p) == (S(2)/7, [(7*p/2, 1)])
34533500

3501+
if not is_field:
3502+
# primitive gcd over Z
3503+
assert (2*(x+y)).gcd(4*(x+y)**2) == 2*(x+y)
3504+
else:
3505+
# monic gcd over Q, Z/pZ and GF(p^d)
3506+
assert (2*(x+y)).gcd(4*(x+y)**2) == x + y
3507+
34543508

34553509
def _all_matrices():
34563510
"""Return a list of matrix types and scalar types."""

src/flint/types/fmpq_mpoly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ cdef class fmpq_mpoly(flint_mpoly):
925925
if fmpq_mpoly_sqrt(res.val, self.val, self.ctx.val):
926926
return res
927927
else:
928-
raise ValueError("polynomial is not a perfect square")
928+
raise DomainError("polynomial is not a perfect square")
929929

930930
def factor(self):
931931
"""

src/flint/types/fmpz_mod_mpoly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
827827
if fmpz_mod_mpoly_sqrt(res.val, self.val, self.ctx.val):
828828
return res
829829
else:
830-
raise ValueError("polynomial is not a perfect square")
830+
raise DomainError("polynomial is not a perfect square")
831831

832832
def factor(self):
833833
"""

src/flint/types/fmpz_mod_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,7 @@ cdef class fmpz_mod_poly(flint_poly):
14961496
res.val, self.val, res.ctx.mod.val
14971497
)
14981498
if check != 1:
1499-
raise ValueError(
1499+
raise DomainError(
15001500
f"Cannot compute square-root {self}"
15011501
)
15021502
return res

src/flint/types/fmpz_mpoly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ cdef class fmpz_mpoly(flint_mpoly):
903903
if fmpz_mpoly_sqrt_heap(res.val, self.val, self.ctx.val, not assume_perfect_square):
904904
return res
905905
else:
906-
raise ValueError("polynomial is not a perfect square")
906+
raise DomainError("polynomial is not a perfect square")
907907

908908
def factor(self):
909909
"""

src/flint/types/fmpz_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ cdef class fmpz_poly(flint_poly):
591591
if fmpz_poly_sqrt(v.val, self.val):
592592
return v
593593
else:
594-
raise ValueError(f"Cannot compute square root of {self}")
594+
raise DomainError(f"Cannot compute square root of {self}")
595595

596596
def deflation(self):
597597
cdef fmpz_poly v

src/flint/types/fq_default_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ cdef class fq_default_poly(flint_poly):
982982
res.val, self.val, res.ctx.field.val
983983
)
984984
if check != 1:
985-
raise ValueError(
985+
raise DomainError(
986986
f"Cannot compute square-root {self}"
987987
)
988988
return res

src/flint/types/nmod.pyx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,16 @@ cdef class nmod(flint_scalar):
261261
262262
"""
263263
cdef nmod r
264+
cdef mp_limb_t val
264265
r = nmod.__new__(nmod)
265266
r.mod = self.mod
266-
r.val = n_sqrtmod(self.val, self.mod.n)
267267

268-
if r.val == 0:
268+
if self.val == 0:
269+
return r
270+
271+
val = n_sqrtmod(self.val, self.mod.n)
272+
if val == 0:
269273
raise DomainError("no square root exists for %s mod %s" % (self.val, self.mod.n))
270274

275+
r.val = val
271276
return r

src/flint/types/nmod_mpoly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ cdef class nmod_mpoly(flint_mpoly):
803803
if nmod_mpoly_sqrt(res.val, self.val, self.ctx.val):
804804
return res
805805
else:
806-
raise ValueError("polynomial is not a perfect square")
806+
raise DomainError("polynomial is not a perfect square")
807807

808808
def factor(self):
809809
"""

src/flint/types/nmod_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ cdef class nmod_poly(flint_poly):
700700
if nmod_poly_sqrt(res.val, self.val):
701701
return res
702702
else:
703-
raise ValueError(f"Cannot compute square root of {self}")
703+
raise DomainError(f"Cannot compute square root of {self}")
704704

705705
def deflation(self):
706706
cdef nmod_poly v

0 commit comments

Comments
 (0)