Skip to content

Commit 1556939

Browse files
committed
fix: add __rdunder__ fmpz_mat, fmpq_mat and nmp_mat
1 parent 978696a commit 1556939

File tree

4 files changed

+97
-79
lines changed

4 files changed

+97
-79
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646

4747
define_macros = []
48-
compiler_directives = {'language_level':3}
48+
compiler_directives = {'language_level':3, 'binding':False}
4949

5050

5151
# Enable coverage tracing

src/flint/fmpq_mat.pyx

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,6 @@ cdef class fmpq_mat(flint_mat):
145145
cdef fmpq_mat u
146146
cdef fmpq_mat_struct *sval
147147
cdef fmpq_mat_struct *tval
148-
s = any_as_fmpq_mat(s)
149-
if s is NotImplemented:
150-
return s
151148
t = any_as_fmpq_mat(t)
152149
if t is NotImplemented:
153150
return t
@@ -165,9 +162,6 @@ cdef class fmpq_mat(flint_mat):
165162
cdef fmpq_mat u
166163
cdef fmpq_mat_struct *sval
167164
cdef fmpq_mat_struct *tval
168-
s = any_as_fmpq_mat(s)
169-
if s is NotImplemented:
170-
return s
171165
t = any_as_fmpq_mat(t)
172166
if t is NotImplemented:
173167
return t
@@ -225,30 +219,31 @@ cdef class fmpq_mat(flint_mat):
225219

226220
def __mul__(s, t):
227221
cdef fmpz_mat u
228-
if typecheck(s, fmpq_mat):
229-
if typecheck(t, fmpq_mat):
230-
return (<fmpq_mat>s).__mul_fmpq_mat(t)
231-
elif typecheck(t, fmpz_mat):
232-
return (<fmpq_mat>s).__mul_fmpz_mat(t)
233-
else:
234-
c = any_as_fmpz(t)
235-
if c is not NotImplemented:
236-
return (<fmpq_mat>s).__mul_fmpz(c)
237-
c = any_as_fmpq(t)
238-
if c is not NotImplemented:
239-
return (<fmpq_mat>s).__mul_fmpq(c)
240-
return NotImplemented
222+
if typecheck(t, fmpq_mat):
223+
return (<fmpq_mat>s).__mul_fmpq_mat(t)
224+
elif typecheck(t, fmpz_mat):
225+
return (<fmpq_mat>s).__mul_fmpz_mat(t)
241226
else:
242-
if typecheck(s, fmpz_mat):
243-
return (<fmpq_mat>t).__mul_r_fmpz_mat(s)
244-
else:
245-
c = any_as_fmpz(s)
246-
if c is not NotImplemented:
247-
return (<fmpq_mat>t).__mul_fmpz(c)
248-
c = any_as_fmpq(s)
249-
if c is not NotImplemented:
250-
return (<fmpq_mat>t).__mul_fmpq(c)
251-
return NotImplemented
227+
c = any_as_fmpz(t)
228+
if c is not NotImplemented:
229+
return (<fmpq_mat>s).__mul_fmpz(c)
230+
c = any_as_fmpq(t)
231+
if c is not NotImplemented:
232+
return (<fmpq_mat>s).__mul_fmpq(c)
233+
return NotImplemented
234+
235+
def __rmul__(s, t):
236+
cdef fmpz_mat u
237+
if typecheck(t, fmpz_mat):
238+
return (<fmpq_mat>s).__mul_r_fmpz_mat(t)
239+
else:
240+
c = any_as_fmpz(t)
241+
if c is not NotImplemented:
242+
return (<fmpq_mat>s).__mul_fmpz(c)
243+
c = any_as_fmpq(t)
244+
if c is not NotImplemented:
245+
return (<fmpq_mat>s).__mul_fmpq(c)
246+
return NotImplemented
252247

253248
@staticmethod
254249
def _div_(fmpq_mat s, t):

src/flint/fmpz_mat.pyx

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,6 @@ cdef class fmpz_mat(flint_mat):
190190
cdef fmpz_mat u
191191
cdef fmpz_mat_struct *sval
192192
cdef fmpz_mat_struct *tval
193-
sm = any_as_fmpz_mat(s)
194-
if sm is NotImplemented:
195-
return sm
196193
tm = any_as_fmpz_mat(t)
197194
if tm is NotImplemented:
198195
return tm
@@ -210,9 +207,6 @@ cdef class fmpz_mat(flint_mat):
210207
cdef fmpz_mat u
211208
cdef fmpz_mat_struct *sval
212209
cdef fmpz_mat_struct *tval
213-
sm = any_as_fmpz_mat(s)
214-
if sm is NotImplemented:
215-
return sm
216210
tm = any_as_fmpz_mat(t)
217211
if tm is NotImplemented:
218212
return tm
@@ -238,7 +232,7 @@ cdef class fmpz_mat(flint_mat):
238232
cdef fmpz_mat_struct *sval
239233
cdef fmpz_mat_struct *tval
240234
cdef int ttype
241-
if typecheck(s, fmpz_mat) and typecheck(t, fmpz_mat):
235+
if typecheck(t, fmpz_mat):
242236
sval = &(<fmpz_mat>s).val[0]
243237
tval = &(<fmpz_mat>t).val[0]
244238
if fmpz_mat_ncols(sval) != fmpz_mat_nrows(tval):
@@ -248,8 +242,6 @@ cdef class fmpz_mat(flint_mat):
248242
fmpz_mat_mul(u.val, sval, tval)
249243
return u
250244
else:
251-
if typecheck(t, fmpz_mat):
252-
s, t = t, s
253245
c = any_as_fmpz(t)
254246
if c is not NotImplemented:
255247
return (<fmpz_mat>s).__mul_fmpz(c)
@@ -259,6 +251,16 @@ cdef class fmpz_mat(flint_mat):
259251
return fmpq_mat(s) * t
260252
return NotImplemented
261253

254+
def __rmul__(s, t):
255+
c = any_as_fmpz(t)
256+
if c is not NotImplemented:
257+
return (<fmpz_mat>s).__mul_fmpz(c)
258+
c = any_as_fmpq(t)
259+
if c is not NotImplemented:
260+
# XXX: improve this
261+
return fmpq_mat(s) * t
262+
return NotImplemented
263+
262264
@staticmethod
263265
def _div_(fmpz_mat s, t):
264266
return s * (1 / fmpq(t))

src/flint/nmod_mat.pyx

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,29 @@ cdef class nmod_mat:
190190
cdef nmod_mat r
191191
cdef nmod_mat_struct *sv
192192
cdef nmod_mat_struct *tv
193-
if typecheck(s, nmod_mat):
194-
sv = &(<nmod_mat>s).val[0]
195-
t = any_as_nmod_mat(t, sv.mod)
196-
if t is NotImplemented:
197-
return t
198-
tv = &(<nmod_mat>t).val[0]
199-
else:
200-
tv = &(<nmod_mat>t).val[0]
201-
s = any_as_nmod_mat(s, tv.mod)
202-
if s is NotImplemented:
203-
return s
204-
sv = &(<nmod_mat>s).val[0]
193+
sv = &(<nmod_mat>s).val[0]
194+
t = any_as_nmod_mat(t, sv.mod)
195+
if t is NotImplemented:
196+
return t
197+
tv = &(<nmod_mat>t).val[0]
198+
if sv.mod.n != tv.mod.n:
199+
raise ValueError("cannot add nmod_mats with different moduli")
200+
if sv.r != tv.r or sv.c != tv.c:
201+
raise ValueError("incompatible shapes for matrix addition")
202+
r = nmod_mat.__new__(nmod_mat)
203+
nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n)
204+
nmod_mat_add(r.val, sv, tv)
205+
return r
206+
207+
def __radd__(s, t):
208+
cdef nmod_mat r
209+
cdef nmod_mat_struct *sv
210+
cdef nmod_mat_struct *tv
211+
sv = &(<nmod_mat>s).val[0]
212+
t = any_as_nmod_mat(t, sv.mod)
213+
if t is NotImplemented:
214+
return t
215+
tv = &(<nmod_mat>t).val[0]
205216
if sv.mod.n != tv.mod.n:
206217
raise ValueError("cannot add nmod_mats with different moduli")
207218
if sv.r != tv.r or sv.c != tv.c:
@@ -215,18 +226,11 @@ cdef class nmod_mat:
215226
cdef nmod_mat r
216227
cdef nmod_mat_struct *sv
217228
cdef nmod_mat_struct *tv
218-
if typecheck(s, nmod_mat):
219-
sv = &(<nmod_mat>s).val[0]
220-
t = any_as_nmod_mat(t, sv.mod)
221-
if t is NotImplemented:
222-
return t
223-
tv = &(<nmod_mat>t).val[0]
224-
else:
225-
tv = &(<nmod_mat>t).val[0]
226-
s = any_as_nmod_mat(s, tv.mod)
227-
if s is NotImplemented:
228-
return s
229-
sv = &(<nmod_mat>s).val[0]
229+
sv = &(<nmod_mat>s).val[0]
230+
t = any_as_nmod_mat(t, sv.mod)
231+
if t is NotImplemented:
232+
return t
233+
tv = &(<nmod_mat>t).val[0]
230234
if sv.mod.n != tv.mod.n:
231235
raise ValueError("cannot subtract nmod_mats with different moduli")
232236
if sv.r != tv.r or sv.c != tv.c:
@@ -236,6 +240,24 @@ cdef class nmod_mat:
236240
nmod_mat_sub(r.val, sv, tv)
237241
return r
238242

243+
def __rsub__(s, t):
244+
cdef nmod_mat r
245+
cdef nmod_mat_struct *sv
246+
cdef nmod_mat_struct *tv
247+
sv = &(<nmod_mat>s).val[0]
248+
t = any_as_nmod_mat(t, sv.mod)
249+
if t is NotImplemented:
250+
return t
251+
tv = &(<nmod_mat>t).val[0]
252+
if sv.mod.n != tv.mod.n:
253+
raise ValueError("cannot subtract nmod_mats with different moduli")
254+
if sv.r != tv.r or sv.c != tv.c:
255+
raise ValueError("incompatible shapes for matrix subtraction")
256+
r = nmod_mat.__new__(nmod_mat)
257+
nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n)
258+
nmod_mat_sub(r.val, tv, sv)
259+
return r
260+
239261
cdef __mul_nmod(self, mp_limb_t c):
240262
cdef nmod_mat r = nmod_mat.__new__(nmod_mat)
241263
nmod_mat_init(r.val, self.val.r, self.val.c, self.val.mod.n)
@@ -247,22 +269,13 @@ cdef class nmod_mat:
247269
cdef nmod_mat_struct *sv
248270
cdef nmod_mat_struct *tv
249271
cdef mp_limb_t c
250-
if typecheck(s, nmod_mat):
251-
sv = &(<nmod_mat>s).val[0]
252-
u = any_as_nmod_mat(t, sv.mod)
253-
if u is NotImplemented:
254-
if any_as_nmod(&c, t, sv.mod):
255-
return (<nmod_mat>s).__mul_nmod(c)
256-
return NotImplemented
257-
tv = &(<nmod_mat>u).val[0]
258-
else:
259-
tv = &(<nmod_mat>t).val[0]
260-
u = any_as_nmod_mat(s, tv.mod)
261-
if u is NotImplemented:
262-
if any_as_nmod(&c, s, tv.mod):
263-
return (<nmod_mat>t).__mul_nmod(c)
264-
return NotImplemented
265-
sv = &(<nmod_mat>u).val[0]
272+
sv = &(<nmod_mat>s).val[0]
273+
u = any_as_nmod_mat(t, sv.mod)
274+
if u is NotImplemented:
275+
if any_as_nmod(&c, t, sv.mod):
276+
return (<nmod_mat>s).__mul_nmod(c)
277+
return NotImplemented
278+
tv = &(<nmod_mat>u).val[0]
266279
if sv.mod.n != tv.mod.n:
267280
raise ValueError("cannot multiply nmod_mats with different moduli")
268281
if sv.c != tv.r:
@@ -272,6 +285,14 @@ cdef class nmod_mat:
272285
nmod_mat_mul(r.val, sv, tv)
273286
return r
274287

288+
def __rmul__(s, t):
289+
cdef nmod_mat_struct *sv
290+
cdef mp_limb_t c
291+
sv = &(<nmod_mat>s).val[0]
292+
if any_as_nmod(&c, t, sv.mod):
293+
return (<nmod_mat>s).__mul_nmod(c)
294+
return NotImplemented
295+
275296
@staticmethod
276297
def _div_(nmod_mat s, t):
277298
cdef mp_limb_t v

0 commit comments

Comments
 (0)