Skip to content

Commit dcbe363

Browse files
committed
Add .sqrt() for all scalar types and add docstrings
1 parent fe7265d commit dcbe363

File tree

10 files changed

+238
-37
lines changed

10 files changed

+238
-37
lines changed

src/flint/test/test_all.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def test_fmpz_factor():
285285
(1296814508839693536173209832765271992846610925502473758289451540212712414540699659186801, 1)]
286286

287287
def test_fmpz_functions():
288-
T, F, VE, OE = True, False, ValueError, OverflowError
288+
T, F, VE, OE, DE = True, False, ValueError, OverflowError, DomainError
289289
cases = [
290290
# (f, [f(-1), f(0), f(1), f(2), ... f(10)]),
291291
(lambda n: flint.fmpz(n).is_prime(),
@@ -331,11 +331,11 @@ def test_fmpz_functions():
331331
(lambda n: flint.fmpz(n).euler_phi(),
332332
[0, 0, 1, 1, 2, 2, 4, 2, 6, 4, 6, 4]),
333333
(lambda n: flint.fmpz(n).isqrt(),
334-
[VE, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]),
334+
[DE, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3]),
335335
(lambda n: flint.fmpz(n).sqrtrem(),
336-
[VE, (0, 0), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (3, 0), (3, 1)]),
336+
[DE, (0, 0), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (3, 0), (3, 1)]),
337337
(lambda n: flint.fmpz(n).sqrtmod(11),
338-
[VE, 0, 1, VE, 5, 2, 4, VE, VE, VE, 3, VE]),
338+
[DE, 0, 1, DE, 5, 2, 4, DE, DE, DE, 3, DE]),
339339
(lambda n: flint.fmpz(n).root(3),
340340
[VE, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2]),
341341
(lambda n: flint.fmpz(n).jacobi(3),
@@ -2781,8 +2781,6 @@ def setbad(obj, i, val):
27812781
elif nmod_poly_will_crash:
27822782
pass
27832783
else:
2784-
# fmpz_mod_poly.sqrt() also crashes here:
2785-
# GR_MUST_SUCCEED failure: src/fmpz_mod_poly/sqrt_series.c
27862784
assert raises(lambda: P([1, 2, 1]).sqrt(), DomainError)
27872785

27882786
if P == flint.fmpq_poly:
@@ -3352,21 +3350,6 @@ def factor_sqf(p):
33523350
check(p, coeff, factors)
33533351
return coeff, sort(factors)
33543352

3355-
def sqrt(a):
3356-
if type(x) is flint.fq_default_poly:
3357-
try:
3358-
return S(a).sqrt()
3359-
except ValueError:
3360-
return None
3361-
elif characteristic != 0:
3362-
# XXX: fmpz(0).sqrtmod crashes
3363-
try:
3364-
return flint.fmpz(-1).sqrtmod(characteristic)
3365-
except ValueError:
3366-
return None
3367-
else:
3368-
return None
3369-
33703353
for P, S, [x, y], is_field, characteristic in _all_polys_mpolys():
33713354

33723355
if characteristic != 0 and not characteristic.is_prime():
@@ -3397,25 +3380,39 @@ def sqrt(a):
33973380
assert factor_sqf(2*(x+1)) == (S(2), [(x+1, 1)])
33983381

33993382
assert factor(x*(x+1)) == (S(1), [(x, 1), (x+1, 1)])
3383+
3384+
# mpoly types have a slightly different squarefree factorisation
3385+
# because they handle trivial factors differently. It looks like a
3386+
# monomial gcd is extracted but not recombined so the square-free
3387+
# factors might not have unique multiplicities.
3388+
#
3389+
# Maybe it is worth making them consistent by absorbing the power
3390+
# of x into a factor of equal multiplicity.
34003391
if y is None:
3392+
# *_poly types
34013393
assert factor_sqf(x*(x+1)) == (S(1), [(x**2+x, 1)])
34023394
else:
3403-
# mpoly types have a different squarefree factorisation because
3404-
# they handle trivial factors differently...
3405-
#
3406-
# Maybe it is worth making them consistent by absorbing the power
3407-
# of x into a factor of equal multiplicity.
3395+
# *_mpoly types
34083396
assert factor_sqf(x*(x+1)) == (S(1), [(x, 1), (x+1, 1)])
34093397

3398+
# This is the same for all types because the extracted monomial has
3399+
# a unique multiplicity.
34103400
assert factor_sqf(x**2*(x+1)) == (S(1), [(x+1, 1), (x, 2)])
34113401

3402+
# This is the same for all types because there is no tivial monomial
3403+
# factor to extract.
34123404
assert factor((x-1)*(x+1)) == (S(1), sort([(x-1, 1), (x+1, 1)]))
34133405
assert factor_sqf((x-1)*(x+1)) == (S(1), [(x**2-1, 1)])
34143406

3407+
# Some finite fields have sqrt(-1) so we can factor x**2 + 1
3408+
try:
3409+
i = S(-1).sqrt()
3410+
except DomainError:
3411+
i = None
3412+
34153413
p = 3*(x-1)**2*(x+1)**2*(x**2 + 1)**3
34163414
assert factor_sqf(p) == (S(3), [(x**2 - 1, 2), (x**2 + 1, 3)])
34173415

3418-
i = sqrt(-1)
34193416
if i is not None:
34203417
assert factor(p) == (S(3), sort([(x+1, 2), (x-1, 2), (x+i, 3), (x-i, 3)]))
34213418
else:
@@ -4163,7 +4160,7 @@ def test_fq_default():
41634160
nqr = gf.random_element()
41644161
if not nqr.is_square():
41654162
break
4166-
assert raises(lambda: nqr.sqrt(), ValueError)
4163+
assert raises(lambda: nqr.sqrt(), DomainError)
41674164

41684165

41694166
def test_fq_default_poly():

src/flint/types/fmpq.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,17 @@ cdef class fmpq(flint_scalar):
486486
return v
487487
else:
488488
raise OverflowError("fmpq_pow_fmpz(): exponent too large")
489+
490+
def sqrt(self):
491+
"""
492+
Return exact rational square root of self or raise an error.
493+
494+
>>> fmpq(9, 4).sqrt()
495+
3/2
496+
>>> fmpq(8).sqrt()
497+
Traceback (most recent call last):
498+
...
499+
flint.utils.flint_exceptions.DomainError: not a square number
500+
501+
"""
502+
return fmpq(self.numer().sqrt(), self.denom().sqrt())

src/flint/types/fmpq_mpoly.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,19 @@ cdef class fmpq_mpoly(flint_mpoly):
859859
return res
860860

861861
def leading_coefficient(self):
862+
"""
863+
Leading coefficient in the monomial ordering.
864+
865+
>>> from flint import Ordering
866+
>>> ctx = fmpq_mpoly_ctx(2, Ordering.lex, ['x', 'y'])
867+
>>> x, y = ctx.gens()
868+
>>> p = 2*x*y + 3*x + 4*y**2 + 5
869+
>>> p
870+
2*x*y + 3*x + 4*y^2 + 5
871+
>>> p.leading_coefficient()
872+
2
873+
874+
"""
862875
if fmpq_mpoly_is_zero(self.val, self.ctx.val):
863876
return fmpq(0)
864877
else:

src/flint/types/fmpz.pyx

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -883,30 +883,102 @@ cdef class fmpz(flint_scalar):
883883
return self.bit_length()
884884

885885
def isqrt(self):
886+
"""
887+
Return square root rounded down.
888+
889+
>>> fmpz(9).isqrt()
890+
3
891+
>>> fmpz(8).isqrt()
892+
2
893+
894+
"""
895+
cdef fmpz v
896+
897+
if fmpz_sgn(self.val) < 0:
898+
raise DomainError("integer square root of a negative number")
899+
900+
v = fmpz()
901+
fmpz_sqrt(v.val, self.val)
902+
return v
903+
904+
def sqrt(self):
905+
"""
906+
Return exact integer square root of self or raise an error.
907+
908+
>>> fmpz(9).sqrt()
909+
3
910+
>>> fmpz(8).sqrt()
911+
Traceback (most recent call last):
912+
...
913+
flint.utils.flint_exceptions.DomainError: not a square number
914+
915+
"""
886916
cdef fmpz v
917+
887918
if fmpz_sgn(self.val) < 0:
888-
raise ValueError("integer square root of a negative number")
919+
raise DomainError("integer square root of a negative number")
920+
889921
v = fmpz()
890922
fmpz_sqrt(v.val, self.val)
923+
924+
c = fmpz()
925+
fmpz_mul(c.val, v.val, v.val)
926+
if not fmpz_equal(c.val, self.val):
927+
raise DomainError("not a square number")
928+
891929
return v
892930

893931
def sqrtrem(self):
932+
"""
933+
Return the integer square root of self and remainder.
934+
935+
>>> fmpz(9).sqrtrem()
936+
(3, 0)
937+
>>> fmpz(8).sqrtrem()
938+
(2, 4)
939+
>>> c = fmpz(123456789012345678901234567890)
940+
>>> u, v = c.sqrtrem()
941+
>>> u ** 2 + v == c
942+
True
943+
944+
"""
894945
cdef fmpz u, v
946+
895947
if fmpz_sgn(self.val) < 0:
896-
raise ValueError("integer square root of a negative number")
948+
raise DomainError("integer square root of a negative number")
949+
897950
u = fmpz()
898951
v = fmpz()
899952
fmpz_sqrtrem(u.val, v.val, self.val)
953+
900954
return u, v
901955

902956
# warning: m should be prime!
903-
# also if self is zero this crashes...
904-
def sqrtmod(self, m):
957+
def sqrtmod(self, p):
958+
"""
959+
Return modular square root of self modulo *p* or raise an error.
960+
961+
>>> fmpz(10).sqrtmod(13)
962+
6
963+
>>> (6**2) % 13
964+
10
965+
>>> fmpz(11).sqrtmod(13)
966+
Traceback (most recent call last):
967+
...
968+
flint.utils.flint_exceptions.DomainError: modular square root does not exist
969+
970+
The modulus *p* must be a prime number.
971+
"""
905972
cdef fmpz v
973+
906974
v = fmpz()
907-
m = fmpz(m)
908-
if not fmpz_sqrtmod(v.val, self.val, (<fmpz>m).val):
909-
raise ValueError("unable to compute modular square root")
975+
if fmpz_is_zero(self.val):
976+
return v
977+
978+
p = fmpz(p)
979+
if not fmpz_sqrtmod(v.val, self.val, (<fmpz>p).val):
980+
raise DomainError("modular square root does not exist")
981+
910982
return v
911983

912984
def root(self, long n):

src/flint/types/fmpz_mod.pyx

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from flint.flintlib.fmpz cimport(
1010
fmpz_is_probabprime,
1111
fmpz_mul,
1212
fmpz_invmod,
13+
fmpz_sqrtmod,
1314
fmpz_divexact,
1415
fmpz_gcd,
1516
fmpz_is_one,
@@ -29,6 +30,9 @@ from flint.types.fmpz cimport(
2930
cimport cython
3031
cimport libc.stdlib
3132

33+
from flint.utils.flint_exceptions import DomainError
34+
35+
3236
cdef class fmpz_mod_ctx:
3337
r"""
3438
Context object for creating :class:`~.fmpz_mod` initalised
@@ -578,3 +582,34 @@ cdef class fmpz_mod(flint_scalar):
578582
)
579583

580584
return res
585+
586+
def sqrt(self):
587+
"""
588+
Return the square root of this ``fmpz_mod`` or raise an exception.
589+
590+
>>> ctx = fmpz_mod_ctx(13)
591+
>>> s = ctx(10).sqrt()
592+
>>> s
593+
fmpz_mod(6, 13)
594+
>>> s * s
595+
fmpz_mod(10, 13)
596+
>>> ctx(11).sqrt()
597+
Traceback (most recent call last):
598+
...
599+
flint.utils.flint_exceptions.DomainError: no square root exists for 11 mod 13
600+
601+
The modulus must be prime.
602+
603+
"""
604+
cdef fmpz_mod v
605+
606+
v = fmpz_mod.__new__(fmpz_mod)
607+
v.ctx = self.ctx
608+
609+
if fmpz_is_zero(self.val):
610+
return v
611+
612+
if not fmpz_sqrtmod(v.val, self.val, self.ctx.val.n):
613+
raise DomainError("no square root exists for {} mod {}".format(self, self.ctx.modulus()))
614+
615+
return v

src/flint/types/fmpz_mod_mpoly.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,19 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
758758
return res
759759

760760
def leading_coefficient(self):
761+
"""
762+
Leading coefficient in the monomial ordering.
763+
764+
>>> from flint import Ordering
765+
>>> ctx = fmpz_mod_mpoly_ctx(2, Ordering.lex, ['x', 'y'], 11)
766+
>>> x, y = ctx.gens()
767+
>>> p = 2*x*y + 3*x + 4*y**2 + 5
768+
>>> p
769+
2*x*y + 3*x + 4*y^2 + 5
770+
>>> p.leading_coefficient()
771+
2
772+
773+
"""
761774
if fmpz_mod_mpoly_is_zero(self.val, self.ctx.val):
762775
return fmpz(0)
763776
else:

src/flint/types/fmpz_mpoly.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,19 @@ cdef class fmpz_mpoly(flint_mpoly):
837837
return res
838838

839839
def leading_coefficient(self):
840+
"""
841+
Leading coefficient in the monomial ordering.
842+
843+
>>> from flint import Ordering
844+
>>> ctx = fmpz_mpoly_ctx(2, Ordering.lex, ['x', 'y'])
845+
>>> x, y = ctx.gens()
846+
>>> p = 2*x*y + 3*x + 4*y**2 + 5
847+
>>> p
848+
2*x*y + 3*x + 4*y^2 + 5
849+
>>> p.leading_coefficient()
850+
2
851+
852+
"""
840853
if fmpz_mpoly_is_zero(self.val, self.ctx.val):
841854
return fmpz(0)
842855
else:

src/flint/types/fq_default.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from flint.types.fmpz_mod_poly cimport fmpz_mod_poly, fmpz_mod_poly_ctx
77
from flint.types.nmod_poly cimport nmod_poly
88
from flint.utils.typecheck cimport typecheck
99

10+
from flint.utils.flint_exceptions import DomainError
11+
1012
# Allow the type to be denoted by strings or integers
1113
FQ_TYPES = {
1214
"FQ_ZECH" : 1,
@@ -750,7 +752,7 @@ cdef class fq_default(flint_scalar):
750752
check = fq_default_sqrt(res.val, self.val, self.ctx.val)
751753
if check:
752754
return res
753-
raise ValueError("element is not a square")
755+
raise DomainError("element is not a square")
754756

755757
def is_square(self):
756758
"""

0 commit comments

Comments
 (0)