Skip to content

Commit cba4721

Browse files
committed
Add rounding functions for fmpq and tests
1 parent f347f18 commit cba4721

File tree

3 files changed

+189
-25
lines changed

3 files changed

+189
-25
lines changed

src/flint/fmpq.pyx

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,20 @@ cdef class fmpq(flint_scalar):
136136
return - int((-self.p) // self.q)
137137

138138
def __floor__(self):
139-
return int(self.p // self.q)
139+
return self.floor()
140+
141+
def __ceil__(self):
142+
return self.ceil()
143+
144+
def __trunc__(self):
145+
return self.trunc()
140146

141147
def __nonzero__(self):
142148
return not fmpq_is_zero(self.val)
143149

150+
def __round__(self, ndigits=None):
151+
return self.round(ndigits)
152+
144153
def __pos__(self):
145154
return self
146155

@@ -351,6 +360,37 @@ cdef class fmpq(flint_scalar):
351360
fmpz_cdiv_q(r.val, fmpq_numref(self.val), fmpq_denref(self.val))
352361
return r
353362

363+
def trunc(self):
364+
"""
365+
Truncation function.
366+
367+
>>> fmpq(3,2).trunc()
368+
1
369+
>>> fmpq(-3,2).trunc()
370+
-1
371+
"""
372+
cdef fmpz r = fmpz.__new__(fmpz)
373+
fmpz_tdiv_q(r.val, fmpq_numref(self.val), fmpq_denref(self.val))
374+
return r
375+
376+
def round(self, ndigits=None):
377+
"""
378+
Rounding function.
379+
380+
>>> fmpq(3,2).round()
381+
2
382+
>>> fmpq(-3,2).round()
383+
-2
384+
"""
385+
from fractions import Fraction
386+
fself = Fraction(int(self.p), int(self.q))
387+
if ndigits is not None:
388+
fround = round(fself, ndigits)
389+
return fmpq(fround.numerator, fround.denominator)
390+
else:
391+
fround = round(fself)
392+
return fmpz(fround)
393+
354394
def __hash__(self):
355395
from fractions import Fraction
356396
return hash(Fraction(int(self.p), int(self.q), _normalize=False))

src/flint/fmpz.pyx

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ cdef class fmpz(flint_scalar):
369369

370370
if not success:
371371
if ttype == FMPZ_TMP: fmpz_clear(tval)
372-
raise ValueError("fmpz_pow_fmpz: exponent too large")
372+
raise OverflowError("fmpz_pow_fmpz: exponent too large")
373373
else:
374374
# Modular exponentiation
375375
mtype = fmpz_set_any_ref(mval, m)
@@ -406,16 +406,21 @@ cdef class fmpz(flint_scalar):
406406
if typecheck(other, fmpz):
407407
other = int(other)
408408
if typecheck(other, int):
409+
if other < 0:
410+
raise ValueError("negative shift count")
409411
u = fmpz.__new__(fmpz)
410412
fmpz_mul_2exp((<fmpz>u).val, self.val, other)
411413
return u
412414
else:
413415
return NotImplemented
414416

415417
def __rlshift__(self, other):
418+
iself = int(self)
419+
if iself < 0:
420+
raise ValueError("negative shift count")
416421
if typecheck(other, int):
417422
u = fmpz.__new__(fmpz)
418-
fmpz_mul_2exp((<fmpz>u).val, fmpz(other).val, int(self))
423+
fmpz_mul_2exp((<fmpz>u).val, fmpz(other).val, iself)
419424
return u
420425
else:
421426
return NotImplemented
@@ -424,16 +429,21 @@ cdef class fmpz(flint_scalar):
424429
if typecheck(other, fmpz):
425430
other = int(other)
426431
if typecheck(other, int):
432+
if other < 0:
433+
raise ValueError("negative shift count")
427434
u = fmpz.__new__(fmpz)
428435
fmpz_fdiv_q_2exp((<fmpz>u).val, self.val, other)
429436
return u
430437
else:
431438
return NotImplemented
432439

433440
def __rrshift__(self, other):
441+
iself = int(self)
442+
if iself < 0:
443+
raise ValueError("negative shift count")
434444
if typecheck(other, int):
435445
u = fmpz.__new__(fmpz)
436-
fmpz_fdiv_q_2exp((<fmpz>u).val, fmpz(other).val, int(self))
446+
fmpz_fdiv_q_2exp((<fmpz>u).val, fmpz(other).val, iself)
437447
return u
438448
else:
439449
return NotImplemented

test/test.py

Lines changed: 135 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import sys
2-
import flint
2+
import math
33
import operator
4+
import pickle
45
import doctest
56

7+
import flint
8+
69
if sys.version_info[0] >= 3:
710
long = int
811

@@ -85,6 +88,20 @@ def test_fmpz():
8588
assert ltype(s) + rtype(t) == s + t
8689
assert ltype(s) - rtype(t) == s - t
8790
assert ltype(s) * rtype(t) == s * t
91+
assert ltype(s) & rtype(t) == s & t
92+
assert ~ltype(s) == ~s
93+
94+
# XXX: Some sort of bug here means that the following fails
95+
# for values of s and t bigger than the word size.
96+
if abs(s) < 2**62 and abs(t) < 2**62:
97+
assert ltype(s) | rtype(t) == s | t
98+
else:
99+
# This still works so somehow the internal representation
100+
# of the fmpz made by fmpz_or is such that fmpz_equal
101+
# returns false even though the values look the same.
102+
assert str(ltype(s) | rtype(t)) == str(s | t)
103+
104+
assert ltype(s) ^ rtype(t) == s ^ t
88105
if t == 0:
89106
assert raises(lambda: ltype(s) // rtype(t), ZeroDivisionError)
90107
assert raises(lambda: ltype(s) % rtype(t), ZeroDivisionError)
@@ -101,6 +118,20 @@ def test_fmpz():
101118
assert (ltype(s) >= rtype(t)) == (s >= t)
102119
if 0 <= t < 10:
103120
assert (ltype(s) ** rtype(t)) == (s ** t)
121+
assert ltype(s) << rtype(t) == s << t
122+
assert ltype(s) >> rtype(t) == s >> t
123+
elif -10 <= t < 0:
124+
assert raises(lambda: ltype(s) << rtype(t), ValueError)
125+
assert raises(lambda: ltype(s) >> rtype(t), ValueError)
126+
127+
assert 2 ** flint.fmpz(2) == 4
128+
assert type(2 ** flint.fmpz(2)) == flint.fmpz
129+
assert raises(lambda: () ** flint.fmpz(1), TypeError)
130+
assert raises(lambda: flint.fmpz(1) ** (), TypeError)
131+
assert raises(lambda: flint.fmpz(1) ** -1, ValueError)
132+
133+
mega = flint.fmpz(2) ** 8000000
134+
assert raises(lambda: mega ** mega, OverflowError)
104135

105136
pow_mod_examples = [
106137
(2, 2, 3, 1),
@@ -118,6 +149,19 @@ def test_fmpz():
118149
# XXX: Handle negative modulus like int?
119150
assert raises(lambda: pow(flint.fmpz(2), 2, -1), ValueError)
120151

152+
f = flint.fmpz(2)
153+
assert f.numerator == f
154+
assert type(f.numerator) is flint.fmpz
155+
assert f.denominator == 1
156+
assert type(f.denominator) is flint.fmpz
157+
158+
assert int(f) == 2
159+
assert type(int(f)) is int
160+
assert operator.index(f) == 2
161+
assert type(operator.index(f)) is int
162+
assert float(f) == 2.0
163+
assert type(float(f)) is float
164+
121165
assert flint.fmpz(2) != []
122166
assert +flint.fmpz(0) == 0
123167
assert +flint.fmpz(1) == 1
@@ -128,30 +172,68 @@ def test_fmpz():
128172
assert abs(flint.fmpz(0)) == 0
129173
assert abs(flint.fmpz(1)) == 1
130174
assert abs(flint.fmpz(-1)) == 1
131-
assert int(flint.fmpz(2)) == 2
132-
assert isinstance(int(flint.fmpz(2)), int)
133-
assert long(flint.fmpz(2)) == 2
134-
assert isinstance(long(flint.fmpz(2)), long)
135-
l = [1, 2, 3]
136-
l[flint.fmpz(1)] = -2
137-
assert l == [1, -2, 3]
138-
d = {flint.fmpz(2): 3}
139-
d[flint.fmpz(2)] = -1
140-
assert d == {flint.fmpz(2): -1}
175+
176+
assert bool(flint.fmpz(0)) == False
177+
assert bool(flint.fmpz(1)) == True
178+
141179
assert flint.fmpz(2).bit_length() == 2
142180
assert flint.fmpz(-2).bit_length() == 2
143181
assert flint.fmpz(2).height_bits() == 2
144182
assert flint.fmpz(-2).height_bits() == 2
145183
assert flint.fmpz(2).height_bits(signed=True) == 2
146184
assert flint.fmpz(-2).height_bits(signed=True) == -2
185+
186+
f1 = flint.fmpz(1)
187+
f2 = flint.fmpz(2)
188+
f3 = flint.fmpz(3)
189+
f8 = flint.fmpz(8)
190+
191+
assert f2 << 2 == 8
192+
assert f2 << f2 == 8
193+
assert 2 << f2 == 8
194+
assert raises(lambda: f2 << -1, ValueError)
195+
assert raises(lambda: 2 << -f1, ValueError)
196+
197+
assert f8 >> 2 == f2
198+
assert f8 >> f2 == f2
199+
assert 8 >> f2 == f2
200+
assert raises(lambda: f2 >> -1, ValueError)
201+
assert raises(lambda: 2 >> -f1, ValueError)
202+
203+
assert f2 & 3 == 2
204+
assert f2 & f3 == 2
205+
assert 2 & f3 == 2
206+
assert f2 | 3 == 3
207+
assert f2 | f3 == 3
208+
assert 2 | f3 == 3
209+
assert f2 ^ 3 == 1
210+
assert f2 ^ f3 == 1
211+
assert 2 ^ f3 == 1
212+
213+
assert raises(lambda: f2 << (), TypeError)
214+
assert raises(lambda: () << f2, TypeError)
215+
assert raises(lambda: f2 >> (), TypeError)
216+
assert raises(lambda: () >> f2, TypeError)
217+
assert raises(lambda: f2 & (), TypeError)
218+
assert raises(lambda: () & f2, TypeError)
219+
assert raises(lambda: f2 | (), TypeError)
220+
assert raises(lambda: () | f2, TypeError)
221+
assert raises(lambda: f2 ^ (), TypeError)
222+
assert raises(lambda: () ^ f2, TypeError)
223+
224+
l = [1, 2, 3]
225+
l[flint.fmpz(1)] = -2
226+
assert l == [1, -2, 3]
227+
d = {flint.fmpz(2): 3}
228+
d[flint.fmpz(2)] = -1
229+
230+
assert d == {flint.fmpz(2): -1}
147231
ctx.pretty = False
148232
assert repr(flint.fmpz(0)) == "fmpz(0)"
149233
assert repr(flint.fmpz(-27)) == "fmpz(-27)"
150234
ctx.pretty = True
151235
assert repr(flint.fmpz(0)) == "0"
152236
assert repr(flint.fmpz(-27)) == "-27"
153-
assert bool(flint.fmpz(0)) == False
154-
assert bool(flint.fmpz(1)) == True
155237
bigstr = '1' * 100
156238
big = flint.fmpz(bigstr)
157239
assert big.str() == bigstr
@@ -635,6 +717,7 @@ def test_fmpq():
635717
assert 0 == Q(0)
636718
assert Q(2) != 1
637719
assert 1 != Q(2)
720+
assert Q(1) != ()
638721
assert Q(1,2) != 1
639722
assert Q(2,3) == Q(flint.fmpz(2),long(3))
640723
assert Q(-2,-4) == Q(1,2)
@@ -672,6 +755,10 @@ def test_fmpq():
672755
# XXX: This should NotImplementedError or something.
673756
assert raises(lambda: pow(Q(1,2),2,3), AssertionError)
674757

758+
megaz = flint.fmpz(2) ** 8000000
759+
megaq = Q(megaz)
760+
assert raises(lambda: megaq ** megaz, OverflowError)
761+
675762
assert raises(lambda: Q(1,2) + [], TypeError)
676763
assert raises(lambda: Q(1,2) - [], TypeError)
677764
assert raises(lambda: Q(1,2) * [], TypeError)
@@ -700,14 +787,16 @@ def test_fmpq():
700787
assert (Q(1,2) >= Q(1,2)) is True
701788
assert raises(lambda: Q(1,2) > [], TypeError)
702789
assert raises(lambda: [] < Q(1,2), TypeError)
790+
703791
ctx.pretty = False
704792
assert repr(Q(-2,3)) == "fmpq(-2,3)"
705793
assert repr(Q(3)) == "fmpq(3)"
706794
ctx.pretty = True
707795
assert str(Q(-2,3)) == "-2/3"
708796
assert str(Q(3)) == "3"
709-
assert Q(2,3).p == Q(2,3).numer() == 2
710-
assert Q(2,3).q == Q(2,3).denom() == 3
797+
798+
assert Q(2,3).p == Q(2,3).numer() == Q(2,3).numerator == 2
799+
assert Q(2,3).q == Q(2,3).denom() == Q(2,3).denominator == 3
711800
assert +Q(5,7) == Q(5,7)
712801
assert -Q(5,7) == Q(-5,7)
713802
assert -Q(-5,7) == Q(5,7)
@@ -721,12 +810,26 @@ def test_fmpq():
721810
assert Q(-5,3).floor() == flint.fmpz(-2)
722811
assert Q(5,3).ceil() == flint.fmpz(2)
723812
assert Q(-5,3).ceil() == flint.fmpz(-1)
724-
# XXX: Need __floor__ etc.
725-
#
726-
# assert math.floor(Q(5,3)) == flint.fmpz(1)
727-
# assert math.ceil(Q(5,3)) == flint.fmpz(2)
728-
# assert math.trunc(Q(5,3)) == flint.fmpz(2)
729-
# assert round(Q(5,3)) == 2
813+
814+
assert int(Q(5,3)) == flint.fmpz(1)
815+
assert math.floor(Q(5,3)) == flint.fmpz(1)
816+
assert math.ceil(Q(5,3)) == flint.fmpz(2)
817+
assert math.trunc(Q(5,3)) == flint.fmpz(1)
818+
assert round(Q(5,3)) == 2
819+
820+
assert int(Q(-5,3)) == flint.fmpz(-1)
821+
assert math.floor(Q(-5,3)) == flint.fmpz(-2)
822+
assert math.ceil(Q(-5,3)) == flint.fmpz(-1)
823+
assert math.trunc(Q(-5,3)) == flint.fmpz(-1)
824+
assert round(Q(-5,3)) == -2
825+
826+
assert type(round(Q(5,3))) is flint.fmpz
827+
assert type(round(Q(5,3), 0)) is flint.fmpq
828+
assert type(round(Q(5,3), 1)) is flint.fmpq
829+
assert round(Q(100,3), 2) == Q(3333,100)
830+
assert round(Q(100,3), 0) == Q(33,1)
831+
assert round(Q(100,3), -1) == Q(30,1)
832+
assert round(Q(100,3), -2) == Q(0)
730833

731834
d = {}
732835
d[Q(1,2)] = 3
@@ -1429,6 +1532,17 @@ def test_arb():
14291532
assert A(3) != A(2)
14301533
assert not (A("1.1") == A("1.1"))
14311534

1535+
def test_pickling():
1536+
objects = [
1537+
flint.fmpz(1),
1538+
flint.fmpq(1,2),
1539+
# XXX: Add pickling for everything else
1540+
]
1541+
for obj in objects:
1542+
s = pickle.dumps(obj)
1543+
obj2 = pickle.loads(s)
1544+
assert obj == obj2
1545+
14321546
if __name__ == "__main__":
14331547
sys.stdout.write("test_pyflint..."); test_pyflint(); print("OK")
14341548
sys.stdout.write("test_fmpz..."); test_fmpz(); print("OK")

0 commit comments

Comments
 (0)