@@ -190,18 +190,29 @@ cdef class nmod_mat:
190190 cdef nmod_mat r
191191 cdef nmod_mat_struct * sv
192192 cdef nmod_mat_struct * tv
193- if typecheck(s, nmod_mat):
194- sv = & (< nmod_mat> s).val[0 ]
195- t = any_as_nmod_mat(t, sv.mod)
196- if t is NotImplemented :
197- return t
198- tv = & (< nmod_mat> t).val[0 ]
199- else :
200- tv = & (< nmod_mat> t).val[0 ]
201- s = any_as_nmod_mat(s, tv.mod)
202- if s is NotImplemented :
203- return s
204- sv = & (< nmod_mat> s).val[0 ]
193+ sv = & (< nmod_mat> s).val[0 ]
194+ t = any_as_nmod_mat(t, sv.mod)
195+ if t is NotImplemented :
196+ return t
197+ tv = & (< nmod_mat> t).val[0 ]
198+ if sv.mod.n != tv.mod.n:
199+ raise ValueError (" cannot add nmod_mats with different moduli" )
200+ if sv.r != tv.r or sv.c != tv.c:
201+ raise ValueError (" incompatible shapes for matrix addition" )
202+ r = nmod_mat.__new__ (nmod_mat)
203+ nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n)
204+ nmod_mat_add(r.val, sv, tv)
205+ return r
206+
207+ def __radd__ (s , t ):
208+ cdef nmod_mat r
209+ cdef nmod_mat_struct * sv
210+ cdef nmod_mat_struct * tv
211+ sv = & (< nmod_mat> s).val[0 ]
212+ t = any_as_nmod_mat(t, sv.mod)
213+ if t is NotImplemented :
214+ return t
215+ tv = & (< nmod_mat> t).val[0 ]
205216 if sv.mod.n != tv.mod.n:
206217 raise ValueError (" cannot add nmod_mats with different moduli" )
207218 if sv.r != tv.r or sv.c != tv.c:
@@ -215,18 +226,11 @@ cdef class nmod_mat:
215226 cdef nmod_mat r
216227 cdef nmod_mat_struct * sv
217228 cdef nmod_mat_struct * tv
218- if typecheck(s, nmod_mat):
219- sv = & (< nmod_mat> s).val[0 ]
220- t = any_as_nmod_mat(t, sv.mod)
221- if t is NotImplemented :
222- return t
223- tv = & (< nmod_mat> t).val[0 ]
224- else :
225- tv = & (< nmod_mat> t).val[0 ]
226- s = any_as_nmod_mat(s, tv.mod)
227- if s is NotImplemented :
228- return s
229- sv = & (< nmod_mat> s).val[0 ]
229+ sv = & (< nmod_mat> s).val[0 ]
230+ t = any_as_nmod_mat(t, sv.mod)
231+ if t is NotImplemented :
232+ return t
233+ tv = & (< nmod_mat> t).val[0 ]
230234 if sv.mod.n != tv.mod.n:
231235 raise ValueError (" cannot subtract nmod_mats with different moduli" )
232236 if sv.r != tv.r or sv.c != tv.c:
@@ -236,6 +240,24 @@ cdef class nmod_mat:
236240 nmod_mat_sub(r.val, sv, tv)
237241 return r
238242
243+ def __rsub__ (s , t ):
244+ cdef nmod_mat r
245+ cdef nmod_mat_struct * sv
246+ cdef nmod_mat_struct * tv
247+ sv = & (< nmod_mat> s).val[0 ]
248+ t = any_as_nmod_mat(t, sv.mod)
249+ if t is NotImplemented :
250+ return t
251+ tv = & (< nmod_mat> t).val[0 ]
252+ if sv.mod.n != tv.mod.n:
253+ raise ValueError (" cannot subtract nmod_mats with different moduli" )
254+ if sv.r != tv.r or sv.c != tv.c:
255+ raise ValueError (" incompatible shapes for matrix subtraction" )
256+ r = nmod_mat.__new__ (nmod_mat)
257+ nmod_mat_init(r.val, sv.r, sv.c, sv.mod.n)
258+ nmod_mat_sub(r.val, tv, sv)
259+ return r
260+
239261 cdef __mul_nmod(self , mp_limb_t c):
240262 cdef nmod_mat r = nmod_mat.__new__ (nmod_mat)
241263 nmod_mat_init(r.val, self .val.r, self .val.c, self .val.mod.n)
@@ -247,22 +269,13 @@ cdef class nmod_mat:
247269 cdef nmod_mat_struct * sv
248270 cdef nmod_mat_struct * tv
249271 cdef mp_limb_t c
250- if typecheck(s, nmod_mat):
251- sv = & (< nmod_mat> s).val[0 ]
252- u = any_as_nmod_mat(t, sv.mod)
253- if u is NotImplemented :
254- if any_as_nmod(& c, t, sv.mod):
255- return (< nmod_mat> s).__mul_nmod(c)
256- return NotImplemented
257- tv = & (< nmod_mat> u).val[0 ]
258- else :
259- tv = & (< nmod_mat> t).val[0 ]
260- u = any_as_nmod_mat(s, tv.mod)
261- if u is NotImplemented :
262- if any_as_nmod(& c, s, tv.mod):
263- return (< nmod_mat> t).__mul_nmod(c)
264- return NotImplemented
265- sv = & (< nmod_mat> u).val[0 ]
272+ sv = & (< nmod_mat> s).val[0 ]
273+ u = any_as_nmod_mat(t, sv.mod)
274+ if u is NotImplemented :
275+ if any_as_nmod(& c, t, sv.mod):
276+ return (< nmod_mat> s).__mul_nmod(c)
277+ return NotImplemented
278+ tv = & (< nmod_mat> u).val[0 ]
266279 if sv.mod.n != tv.mod.n:
267280 raise ValueError (" cannot multiply nmod_mats with different moduli" )
268281 if sv.c != tv.r:
@@ -272,6 +285,14 @@ cdef class nmod_mat:
272285 nmod_mat_mul(r.val, sv, tv)
273286 return r
274287
288+ def __rmul__ (s , t ):
289+ cdef nmod_mat_struct * sv
290+ cdef mp_limb_t c
291+ sv = & (< nmod_mat> s).val[0 ]
292+ if any_as_nmod(& c, t, sv.mod):
293+ return (< nmod_mat> s).__mul_nmod(c)
294+ return NotImplemented
295+
275296 @staticmethod
276297 def _div_ (nmod_mat s , t ):
277298 cdef mp_limb_t v
0 commit comments