Skip to content

Commit 08bd191

Browse files
committed
Add missing operations for fmpz and fmpq
1 parent 027dd1f commit 08bd191

File tree

4 files changed

+228
-31
lines changed

4 files changed

+228
-31
lines changed

src/flint/_flint.pxd

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ cdef extern from "flint/fmpz.h":
267267
void fmpz_pow_ui(fmpz_t f, fmpz_t g, ulong exp)
268268
void fmpz_powm_ui(fmpz_t f, fmpz_t g, ulong exp, fmpz_t m)
269269
void fmpz_powm(fmpz_t f, fmpz_t g, fmpz_t e, fmpz_t m)
270+
int fmpz_pow_fmpz(fmpz_t f, const fmpz_t g, const fmpz_t x)
270271
int fmpz_sqrtmod(fmpz_t b, fmpz_t a, fmpz_t p)
271272
void fmpz_sqrt(fmpz_t f, fmpz_t g)
272273
void fmpz_sqrtrem(fmpz_t f, fmpz_t r, fmpz_t g)
@@ -310,6 +311,10 @@ cdef extern from "flint/fmpz.h":
310311
int fmpz_jacobi(const fmpz_t a, const fmpz_t p)
311312
int fmpz_is_prime(const fmpz_t n)
312313
int fmpz_is_probabprime(const fmpz_t n)
314+
void fmpz_complement(fmpz_t r, const fmpz_t f)
315+
void fmpz_and(fmpz_t r, const fmpz_t a, const fmpz_t b)
316+
void fmpz_or(fmpz_t r, const fmpz_t a, const fmpz_t b)
317+
void fmpz_xor(fmpz_t r, const fmpz_t a, const fmpz_t b)
313318

314319
cdef extern from "flint/fmpz_factor.h":
315320
ctypedef struct fmpz_factor_struct:
@@ -547,6 +552,7 @@ cdef extern from "flint/fmpq.h":
547552
void fmpq_div(fmpq_t res, fmpq_t op1, fmpq_t op2)
548553
void fmpq_div_fmpz(fmpq_t res, fmpq_t op, fmpz_t x)
549554
int fmpq_mod_fmpz(fmpz_t res, fmpq_t x, fmpz_t mod)
555+
int fmpq_pow_fmpz(fmpq_t a, const fmpq_t b, const fmpz_t e)
550556
int fmpq_reconstruct_fmpz(fmpq_t res, fmpz_t a, fmpz_t m)
551557
int fmpq_reconstruct_fmpz_2(fmpq_t res, fmpz_t a, fmpz_t m, fmpz_t N, fmpz_t D)
552558
mp_bitcnt_t fmpq_height_bits(fmpq_t x)

src/flint/fmpq.pyx

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ cdef class fmpq(flint_scalar):
101101
p = property(numer)
102102
q = property(denom)
103103

104+
# These are the property names in the numeric tower.
105+
numerator = property(numer)
106+
denominator = property(denom)
107+
104108
def repr(self):
105109
if self.q == 1:
106110
return "fmpq(%s)" % self.p
@@ -122,6 +126,15 @@ cdef class fmpq(flint_scalar):
122126
else:
123127
return "%s/%s" % (self.p.str(**kwargs), self.q.str(**kwargs))
124128

129+
def __int__(self):
130+
if self.p >= 0:
131+
return int(self.p // self.q)
132+
else:
133+
return - int((-self.p) // self.q)
134+
135+
def __floor__(self):
136+
return int(self.p // self.q)
137+
125138
def __nonzero__(self):
126139
return not fmpq_is_zero(self.val)
127140

@@ -359,20 +372,28 @@ cdef class fmpq(flint_scalar):
359372
return max(b1, b2)
360373

361374
def __pow__(self, n, z):
375+
cdef fmpz_struct nval[1]
376+
cdef int ntype = FMPZ_UNKNOWN
362377
cdef fmpq v
378+
cdef int success
363379
cdef long e
380+
364381
assert z is None
365-
e = n
366-
if type(self) is fmpq:
367-
v = fmpq.__new__(fmpq)
368-
if e >= 0:
369-
fmpz_pow_ui(fmpq_numref(v.val), fmpq_numref((<fmpq>self).val), e)
370-
fmpz_pow_ui(fmpq_denref(v.val), fmpq_denref((<fmpq>self).val), e)
371-
else:
372-
if fmpq_is_zero((<fmpq>self).val):
373-
raise ZeroDivisionError
374-
fmpz_pow_ui(fmpq_denref(v.val), fmpq_numref((<fmpq>self).val), -e)
375-
fmpz_pow_ui(fmpq_numref(v.val), fmpq_denref((<fmpq>self).val), -e)
376-
return v
377-
return NotImplemented
378382

383+
ntype = fmpz_set_any_ref(nval, n)
384+
if ntype == FMPZ_UNKNOWN:
385+
return NotImplemented
386+
387+
if fmpq_is_zero((<fmpq>self).val) and fmpz_sgn(nval) == -1:
388+
if ntype == FMPZ_TMP: fmpz_clear(nval)
389+
raise ZeroDivisionError
390+
391+
v = fmpq.__new__(fmpq)
392+
success = fmpq_pow_fmpz(v.val, (<fmpq>self).val, nval)
393+
394+
if ntype == FMPZ_TMP: fmpz_clear(nval)
395+
396+
if success:
397+
return v
398+
else:
399+
raise OverflowError("fmpq_pow_fmpz(): exponent too large")

src/flint/fmpz.pyx

Lines changed: 171 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ cdef class fmpz(flint_scalar):
9191
return
9292
raise TypeError("cannot create fmpz from type %s" % type(val))
9393

94+
@property
95+
def numerator(self):
96+
return self
97+
98+
@property
99+
def denominator(self):
100+
return fmpz(1)
101+
94102
# XXX: improve!
95103
def __int__(self):
96104
return fmpz_get_intlong(self.val)
@@ -101,6 +109,9 @@ cdef class fmpz(flint_scalar):
101109
def __index__(self):
102110
return fmpz_get_intlong(self.val)
103111

112+
def __float__(self):
113+
return float(fmpz_get_intlong(self.val))
114+
104115
def __richcmp__(s, t, int op):
105116
cdef bint res = 0
106117
cdef long tl
@@ -334,28 +345,171 @@ cdef class fmpz(flint_scalar):
334345
return u
335346

336347
def __pow__(s, t, m):
337-
cdef ulong exp
338-
u = NotImplemented
339-
if m is not None:
340-
raise NotImplementedError("modular exponentiation")
341-
c = t
342-
u = fmpz.__new__(fmpz)
343-
fmpz_pow_ui((<fmpz>u).val, (<fmpz>s).val, c)
344-
return u
345-
346-
def __rpow__(s, t, m):
347348
cdef fmpz_struct tval[1]
348-
cdef int stype = FMPZ_UNKNOWN
349-
cdef ulong exp
349+
cdef fmpz_struct mval[1]
350+
cdef int ttype = FMPZ_UNKNOWN
351+
cdef int mtype = FMPZ_UNKNOWN
352+
cdef int success
350353
u = NotImplemented
351-
if m is not None:
352-
raise NotImplementedError("modular exponentiation")
353354
ttype = fmpz_set_any_ref(tval, t)
354-
if ttype != FMPZ_UNKNOWN:
355+
if ttype == FMPZ_UNKNOWN:
356+
return NotImplemented
357+
358+
if m is None:
359+
# fmpz_pow_fmpz throws if x is negative
360+
if fmpz_sgn(tval) == -1:
361+
if ttype == FMPZ_TMP: fmpz_clear(tval)
362+
raise ValueError("negative exponent")
363+
355364
u = fmpz.__new__(fmpz)
356-
s_ulong = fmpz_get_ui(s.val)
357-
fmpz_pow_ui((<fmpz>u).val, tval, s_ulong)
365+
success = fmpz_pow_fmpz((<fmpz>u).val, (<fmpz>s).val, tval)
366+
367+
if not success:
368+
if ttype == FMPZ_TMP: fmpz_clear(tval)
369+
raise ValueError("fmpz_pow_fmpz: exponent too large")
370+
else:
371+
# Modular exponentiation
372+
mtype = fmpz_set_any_ref(mval, m)
373+
if mtype != FMPZ_UNKNOWN:
374+
375+
if fmpz_is_zero(mval):
376+
if ttype == FMPZ_TMP: fmpz_clear(tval)
377+
if mtype == FMPZ_TMP: fmpz_clear(mval)
378+
raise ValueError("pow(): modulus cannot be zero")
379+
380+
# The Flint docs say that fmpz_powm will throw if m is zero
381+
# but it also throws if m is negative. Python generally allows
382+
# e.g. pow(2, 2, -3) == (2^2) % (-3) == -2. We could implement
383+
# that here as well but it is not clear how useful it is.
384+
if fmpz_sgn(mval) == -1:
385+
if ttype == FMPZ_TMP: fmpz_clear(tval)
386+
if mtype == FMPZ_TMP: fmpz_clear(mval)
387+
raise ValueError("pow(): negative modulua not supported")
388+
389+
u = fmpz.__new__(fmpz)
390+
fmpz_powm((<fmpz>u).val, (<fmpz>s).val, tval, mval)
391+
358392
if ttype == FMPZ_TMP: fmpz_clear(tval)
393+
if mtype == FMPZ_TMP: fmpz_clear(mval)
394+
return u
395+
396+
def __rpow__(s, t, m):
397+
t = any_as_fmpz(t)
398+
if t is NotImplemented:
399+
return t
400+
return t.__pow__(s, m)
401+
402+
def __lshift__(self, other):
403+
if typecheck(other, fmpz):
404+
other = int(other)
405+
if typecheck(other, int):
406+
u = fmpz.__new__(fmpz)
407+
fmpz_mul_2exp((<fmpz>u).val, self.val, other)
408+
return u
409+
else:
410+
return NotImplemented
411+
412+
def __rlshift__(self, other):
413+
if typecheck(other, int):
414+
u = fmpz.__new__(fmpz)
415+
fmpz_mul_2exp((<fmpz>u).val, fmpz(other).val, int(self))
416+
return u
417+
else:
418+
return NotImplemented
419+
420+
def __rshift__(self, other):
421+
if typecheck(other, fmpz):
422+
other = int(other)
423+
if typecheck(other, int):
424+
u = fmpz.__new__(fmpz)
425+
fmpz_fdiv_q_2exp((<fmpz>u).val, self.val, other)
426+
return u
427+
else:
428+
return NotImplemented
429+
430+
def __rrshift__(self, other):
431+
if typecheck(other, int):
432+
u = fmpz.__new__(fmpz)
433+
fmpz_fdiv_q_2exp((<fmpz>u).val, fmpz(other).val, int(self))
434+
return u
435+
else:
436+
return NotImplemented
437+
438+
def __and__(self, other):
439+
cdef fmpz_struct tval[1]
440+
cdef int ttype = FMPZ_UNKNOWN
441+
ttype = fmpz_set_any_ref(tval, other)
442+
if ttype == FMPZ_UNKNOWN:
443+
return NotImplemented
444+
u = fmpz.__new__(fmpz)
445+
fmpz_and((<fmpz>u).val, self.val, tval)
446+
if ttype == FMPZ_TMP:
447+
fmpz_clear(tval)
448+
return u
449+
450+
def __rand__(self, other):
451+
cdef fmpz_struct tval[1]
452+
cdef int ttype = FMPZ_UNKNOWN
453+
ttype = fmpz_set_any_ref(tval, other)
454+
if ttype == FMPZ_UNKNOWN:
455+
return NotImplemented
456+
u = fmpz.__new__(fmpz)
457+
fmpz_and((<fmpz>u).val, tval, self.val)
458+
if ttype == FMPZ_TMP:
459+
fmpz_clear(tval)
460+
return u
461+
462+
def __or__(self, other):
463+
cdef fmpz_struct tval[1]
464+
cdef int ttype = FMPZ_UNKNOWN
465+
ttype = fmpz_set_any_ref(tval, other)
466+
if ttype == FMPZ_UNKNOWN:
467+
return NotImplemented
468+
u = fmpz.__new__(fmpz)
469+
fmpz_or((<fmpz>u).val, self.val, tval)
470+
if ttype == FMPZ_TMP:
471+
fmpz_clear(tval)
472+
return u
473+
474+
def __ror__(self, other):
475+
cdef fmpz_struct tval[1]
476+
cdef int ttype = FMPZ_UNKNOWN
477+
ttype = fmpz_set_any_ref(tval, other)
478+
if ttype == FMPZ_UNKNOWN:
479+
return NotImplemented
480+
u = fmpz.__new__(fmpz)
481+
fmpz_or((<fmpz>u).val, tval, self.val)
482+
if ttype == FMPZ_TMP:
483+
fmpz_clear(tval)
484+
return u
485+
486+
def __xor__(self, other):
487+
cdef fmpz_struct tval[1]
488+
cdef int ttype = FMPZ_UNKNOWN
489+
ttype = fmpz_set_any_ref(tval, other)
490+
if ttype == FMPZ_UNKNOWN:
491+
return NotImplemented
492+
u = fmpz.__new__(fmpz)
493+
fmpz_xor((<fmpz>u).val, self.val, tval)
494+
if ttype == FMPZ_TMP:
495+
fmpz_clear(tval)
496+
return u
497+
498+
def __rxor__(self, other):
499+
cdef fmpz_struct tval[1]
500+
cdef int ttype = FMPZ_UNKNOWN
501+
ttype = fmpz_set_any_ref(tval, other)
502+
if ttype == FMPZ_UNKNOWN:
503+
return NotImplemented
504+
u = fmpz.__new__(fmpz)
505+
fmpz_xor((<fmpz>u).val, tval, self.val)
506+
if ttype == FMPZ_TMP:
507+
fmpz_clear(tval)
508+
return u
509+
510+
def __invert__(self):
511+
u = fmpz.__new__(fmpz)
512+
fmpz_complement((<fmpz>u).val, self.val)
359513
return u
360514

361515
def gcd(self, other):

test/test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,23 @@ def test_fmpz():
101101
assert (ltype(s) >= rtype(t)) == (s >= t)
102102
if 0 <= t < 10:
103103
assert (ltype(s) ** rtype(t)) == (s ** t)
104+
105+
pow_mod_examples = [
106+
(2, 2, 3, 1),
107+
(2, -1, 5, 3),
108+
(2, 0, 5, 1),
109+
]
110+
for a, b, c, ab_mod_c in pow_mod_examples:
111+
assert pow(a, b, c) == ab_mod_c
112+
assert pow(flint.fmpz(a), b, c) == ab_mod_c
113+
assert pow(a, flint.fmpz(b), c) == ab_mod_c
114+
assert pow(flint.fmpz(a), flint.fmpz(b), c) == ab_mod_c
115+
assert pow(flint.fmpz(a), flint.fmpz(b), flint.fmpz(c)) == ab_mod_c
116+
117+
assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError)
118+
# XXX: Handle negative modulus like int?
119+
assert raises(lambda: pow(flint.fmpz(2), 2, -1), ValueError)
120+
104121
assert flint.fmpz(2) != []
105122
assert +flint.fmpz(0) == 0
106123
assert +flint.fmpz(1) == 1
@@ -139,7 +156,6 @@ def test_fmpz():
139156
big = flint.fmpz(bigstr)
140157
assert big.str() == bigstr
141158
assert big.str(condense=10) == '1111111111{...80 digits...}1111111111'
142-
assert raises(lambda: pow(flint.fmpz(2), 2, 3), NotImplementedError)
143159

144160
def test_fmpz_factor():
145161
assert flint.fmpz(6).gcd(flint.fmpz(9)) == 3

0 commit comments

Comments
 (0)