Skip to content

Commit 2d4cfa9

Browse files
committed
Add temporary workaround for fmpz_or bug
1 parent e4079c1 commit 2d4cfa9

File tree

2 files changed

+45
-35
lines changed

2 files changed

+45
-35
lines changed

src/flint/fmpz.pyx

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -487,29 +487,45 @@ cdef class fmpz(flint_scalar):
487487
fmpz_clear(tval)
488488
return u
489489

490+
# This is the correct code when fmpz_or is fixed (in flint 3.0.0)
491+
#
492+
#def __or__(self, other):
493+
# cdef fmpz_struct tval[1]
494+
# cdef int ttype = FMPZ_UNKNOWN
495+
# ttype = fmpz_set_any_ref(tval, other)
496+
# if ttype == FMPZ_UNKNOWN:
497+
# return NotImplemented
498+
# u = fmpz.__new__(fmpz)
499+
# fmpz_or((<fmpz>u).val, self.val, tval)
500+
# if ttype == FMPZ_TMP:
501+
# fmpz_clear(tval)
502+
# return u
503+
#
504+
#def __ror__(self, other):
505+
# cdef fmpz_struct tval[1]
506+
# cdef int ttype = FMPZ_UNKNOWN
507+
# ttype = fmpz_set_any_ref(tval, other)
508+
# if ttype == FMPZ_UNKNOWN:
509+
# return NotImplemented
510+
# u = fmpz.__new__(fmpz)
511+
# fmpz_or((<fmpz>u).val, tval, self.val)
512+
# if ttype == FMPZ_TMP:
513+
# fmpz_clear(tval)
514+
# return u
515+
490516
def __or__(self, other):
491-
cdef fmpz_struct tval[1]
492-
cdef int ttype = FMPZ_UNKNOWN
493-
ttype = fmpz_set_any_ref(tval, other)
494-
if ttype == FMPZ_UNKNOWN:
517+
if typecheck(other, fmpz):
518+
other = int(other)
519+
if typecheck(other, int):
520+
return fmpz(int(self) | other)
521+
else:
495522
return NotImplemented
496-
u = fmpz.__new__(fmpz)
497-
fmpz_or((<fmpz>u).val, self.val, tval)
498-
if ttype == FMPZ_TMP:
499-
fmpz_clear(tval)
500-
return u
501523

502524
def __ror__(self, other):
503-
cdef fmpz_struct tval[1]
504-
cdef int ttype = FMPZ_UNKNOWN
505-
ttype = fmpz_set_any_ref(tval, other)
506-
if ttype == FMPZ_UNKNOWN:
525+
if typecheck(other, int):
526+
return fmpz(other | int(self))
527+
else:
507528
return NotImplemented
508-
u = fmpz.__new__(fmpz)
509-
fmpz_or((<fmpz>u).val, tval, self.val)
510-
if ttype == FMPZ_TMP:
511-
fmpz_clear(tval)
512-
return u
513529

514530
def __xor__(self, other):
515531
cdef fmpz_struct tval[1]

test/test.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,22 @@ def test_fmpz():
8585
for t in L:
8686
for ltype in (flint.fmpz, int, long):
8787
for rtype in (flint.fmpz, int, long):
88+
89+
assert (ltype(s) == rtype(t)) == (s == t)
90+
assert (ltype(s) != rtype(t)) == (s != t)
91+
assert (ltype(s) < rtype(t)) == (s < t)
92+
assert (ltype(s) <= rtype(t)) == (s <= t)
93+
assert (ltype(s) > rtype(t)) == (s > t)
94+
assert (ltype(s) >= rtype(t)) == (s >= t)
95+
8896
assert ltype(s) + rtype(t) == s + t
8997
assert ltype(s) - rtype(t) == s - t
9098
assert ltype(s) * rtype(t) == s * t
9199
assert ltype(s) & rtype(t) == s & t
100+
assert ltype(s) | rtype(t) == s | t
101+
assert ltype(s) ^ rtype(t) == s ^ t
92102
assert ~ltype(s) == ~s
93103

94-
# XXX: Some sort of bug here means that the following fails
95-
# for values of s and t bigger than the word size.
96-
if abs(s) < 2**62 and abs(t) < 2**62:
97-
assert ltype(s) | rtype(t) == s | t
98-
else:
99-
# This still works so somehow the internal representation
100-
# of the fmpz made by fmpz_or is such that fmpz_equal
101-
# returns false even though the values look the same.
102-
assert str(ltype(s) | rtype(t)) == str(s | t)
103-
104-
assert ltype(s) ^ rtype(t) == s ^ t
105104
if t == 0:
106105
assert raises(lambda: ltype(s) // rtype(t), ZeroDivisionError)
107106
assert raises(lambda: ltype(s) % rtype(t), ZeroDivisionError)
@@ -110,12 +109,7 @@ def test_fmpz():
110109
assert ltype(s) // rtype(t) == s // t
111110
assert ltype(s) % rtype(t) == s % t
112111
assert divmod(ltype(s), rtype(t)) == divmod(s, t)
113-
assert (ltype(s) == rtype(t)) == (s == t)
114-
assert (ltype(s) != rtype(t)) == (s != t)
115-
assert (ltype(s) < rtype(t)) == (s < t)
116-
assert (ltype(s) <= rtype(t)) == (s <= t)
117-
assert (ltype(s) > rtype(t)) == (s > t)
118-
assert (ltype(s) >= rtype(t)) == (s >= t)
112+
119113
if 0 <= t < 10:
120114
assert (ltype(s) ** rtype(t)) == (s ** t)
121115
assert ltype(s) << rtype(t) == s << t

0 commit comments

Comments
 (0)