Skip to content

Commit 0a8bf5c

Browse files
committed
fix: add __rdunder__ methods for acb_mat and arb_mat
1 parent d3c2519 commit 0a8bf5c

File tree

2 files changed

+121
-114
lines changed

2 files changed

+121
-114
lines changed

src/flint/acb_mat.pyx

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
cdef acb_mat_coerce_operands(x, y):
2-
if typecheck(x, acb_mat):
3-
if isinstance(y, (fmpz_mat, fmpq_mat, arb_mat)):
4-
return x, acb_mat(y)
5-
if isinstance(y, (int, long, float, complex, fmpz, fmpq, arb, acb)):
6-
return x, acb_mat(x.nrows(), x.ncols(), y)
7-
elif typecheck(y, acb_mat):
8-
if isinstance(x, (fmpz_mat, fmpq_mat, arb_mat)):
9-
return acb_mat(x), y
10-
if isinstance(y, (int, long, float, complex, fmpz, fmpq, arb, acb)):
11-
return acb_mat(y.nrows(), y.ncols(), x), y
2+
if isinstance(y, (fmpz_mat, fmpq_mat, arb_mat)):
3+
return x, acb_mat(y)
4+
if isinstance(y, (int, long, float, complex, fmpz, fmpq, arb, acb)):
5+
return x, acb_mat(x.nrows(), x.ncols(), y)
126
return NotImplemented, NotImplemented
137

148
cdef acb_mat_coerce_scalar(x, y):
@@ -226,35 +220,49 @@ cdef class acb_mat(flint_mat):
226220

227221
def __add__(s, t):
228222
cdef long m, n
229-
if type(s) is type(t):
230-
m = (<acb_mat>s).nrows()
231-
n = (<acb_mat>s).ncols()
232-
if m != (<acb_mat>t).nrows() or n != (<acb_mat>t).ncols():
233-
raise ValueError("incompatible shapes for matrix addition")
234-
u = acb_mat.__new__(acb_mat)
235-
acb_mat_init((<acb_mat>u).val, m, n)
236-
acb_mat_add((<acb_mat>u).val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
237-
return u
223+
if not isinstance(t, acb_mat):
224+
s, t = acb_mat_coerce_operands(s, t)
225+
if s is NotImplemented:
226+
return s
227+
return s + t
228+
229+
m = (<acb_mat>s).nrows()
230+
n = (<acb_mat>s).ncols()
231+
if m != (<acb_mat>t).nrows() or n != (<acb_mat>t).ncols():
232+
raise ValueError("incompatible shapes for matrix addition")
233+
u = acb_mat.__new__(acb_mat)
234+
acb_mat_init((<acb_mat>u).val, m, n)
235+
acb_mat_add((<acb_mat>u).val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
236+
return u
237+
238+
def __radd__(s, t):
238239
s, t = acb_mat_coerce_operands(s, t)
239240
if s is NotImplemented:
240241
return s
241-
return s + t
242+
return t + s
242243

243244
def __sub__(s, t):
244245
cdef long m, n
245-
if type(s) is type(t):
246-
m = (<acb_mat>s).nrows()
247-
n = (<acb_mat>s).ncols()
248-
if m != (<acb_mat>t).nrows() or n != (<acb_mat>t).ncols():
249-
raise ValueError("incompatible shapes for matrix addition")
250-
u = acb_mat.__new__(acb_mat)
251-
acb_mat_init((<acb_mat>u).val, m, n)
252-
acb_mat_sub((<acb_mat>u).val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
253-
return u
246+
if not isinstance(t, acb_mat):
247+
s, t = acb_mat_coerce_operands(s, t)
248+
if s is NotImplemented:
249+
return s
250+
return s - t
251+
252+
m = (<acb_mat>s).nrows()
253+
n = (<acb_mat>s).ncols()
254+
if m != (<acb_mat>t).nrows() or n != (<acb_mat>t).ncols():
255+
raise ValueError("incompatible shapes for matrix addition")
256+
u = acb_mat.__new__(acb_mat)
257+
acb_mat_init((<acb_mat>u).val, m, n)
258+
acb_mat_sub((<acb_mat>u).val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
259+
return u
260+
261+
def __rsub__(s, t):
254262
s, t = acb_mat_coerce_operands(s, t)
255263
if s is NotImplemented:
256264
return s
257-
return s - t
265+
return t - s
258266

259267
def _scalar_mul_(s, acb t):
260268
cdef acb_mat u
@@ -265,25 +273,30 @@ cdef class acb_mat(flint_mat):
265273

266274
def __mul__(s, t):
267275
cdef acb_mat u
268-
if type(s) is type(t):
269-
if acb_mat_ncols((<acb_mat>s).val) != acb_mat_nrows((<acb_mat>t).val):
270-
raise ValueError("incompatible shapes for matrix multiplication")
271-
u = acb_mat.__new__(acb_mat)
272-
acb_mat_init(u.val, acb_mat_nrows((<acb_mat>s).val), acb_mat_ncols((<acb_mat>t).val))
273-
acb_mat_mul(u.val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
274-
return u
275-
if typecheck(s, acb_mat):
276+
if not isinstance(t, acb_mat):
276277
c, d = acb_mat_coerce_scalar(s, t)
277278
if c is not NotImplemented:
278279
return c._scalar_mul_(d)
279-
else:
280-
d, c = acb_mat_coerce_scalar(t, s)
281-
if d is not NotImplemented:
282-
return d._scalar_mul_(c)
280+
s, t = acb_mat_coerce_operands(s, t)
281+
if s is NotImplemented:
282+
return s
283+
return s * t
284+
285+
if acb_mat_ncols((<acb_mat>s).val) != acb_mat_nrows((<acb_mat>t).val):
286+
raise ValueError("incompatible shapes for matrix multiplication")
287+
u = acb_mat.__new__(acb_mat)
288+
acb_mat_init(u.val, acb_mat_nrows((<acb_mat>s).val), acb_mat_ncols((<acb_mat>t).val))
289+
acb_mat_mul(u.val, (<acb_mat>s).val, (<acb_mat>t).val, getprec())
290+
return u
291+
292+
def __rmul__(s, t):
293+
c, d = acb_mat_coerce_scalar(s, t)
294+
if c is not NotImplemented:
295+
return c._scalar_mul_(d)
283296
s, t = acb_mat_coerce_operands(s, t)
284297
if s is NotImplemented:
285298
return s
286-
return s * t
299+
return t * s
287300

288301
def _scalar_div_(s, acb t):
289302
cdef acb_mat u
@@ -292,8 +305,7 @@ cdef class acb_mat(flint_mat):
292305
acb_mat_scalar_div_acb(u.val, s.val, t.val, getprec())
293306
return u
294307

295-
@staticmethod
296-
def _div_(s, t):
308+
def __truediv__(s, t):
297309
cdef acb_mat u
298310
if typecheck(s, acb_mat):
299311
s, t = acb_mat_coerce_scalar(s, t)
@@ -302,12 +314,6 @@ cdef class acb_mat(flint_mat):
302314
return s._scalar_div_(t)
303315
return NotImplemented
304316

305-
def __truediv__(s, t):
306-
return acb_mat._div_(s, t)
307-
308-
def __div__(s, t):
309-
return acb_mat._div_(s, t)
310-
311317
def __pow__(s, e, m):
312318
cdef acb_mat u
313319
cdef ulong exp

src/flint/arb_mat.pyx

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
cdef arb_mat_coerce_operands(x, y):
2-
if typecheck(x, arb_mat):
3-
if isinstance(y, (fmpz_mat, fmpq_mat)):
4-
return x, arb_mat(y)
5-
if isinstance(y, (int, long, float, fmpz, fmpq, arb)):
6-
return x, arb_mat(x.nrows(), x.ncols(), y)
7-
if isinstance(y, (complex, acb)):
8-
return acb_mat(x), acb_mat(x.nrows(), x.ncols(), y)
9-
elif typecheck(y, arb_mat):
10-
if isinstance(x, (fmpz_mat, fmpq_mat)):
11-
return arb_mat(x), y
12-
if isinstance(y, (int, long, float, fmpz, fmpq, arb)):
13-
return arb_mat(y.nrows(), y.ncols(), x), y
14-
if isinstance(y, (complex, acb)):
15-
return acb_mat(y.nrows(), y.ncols(), x), acb_mat(y)
2+
if isinstance(y, (fmpz_mat, fmpq_mat)):
3+
return x, arb_mat(y)
4+
if isinstance(y, (int, long, float, fmpz, fmpq, arb)):
5+
return x, arb_mat(x.nrows(), x.ncols(), y)
6+
if isinstance(y, (complex, acb)):
7+
return acb_mat(x), acb_mat(x.nrows(), x.ncols(), y)
168
return NotImplemented, NotImplemented
179

1810
cdef arb_mat_coerce_scalar(x, y):
@@ -216,35 +208,49 @@ cdef class arb_mat(flint_mat):
216208

217209
def __add__(s, t):
218210
cdef long m, n
219-
if type(s) is type(t):
220-
m = (<arb_mat>s).nrows()
221-
n = (<arb_mat>s).ncols()
222-
if m != (<arb_mat>t).nrows() or n != (<arb_mat>t).ncols():
223-
raise ValueError("incompatible shapes for matrix addition")
224-
u = arb_mat.__new__(arb_mat)
225-
arb_mat_init((<arb_mat>u).val, m, n)
226-
arb_mat_add((<arb_mat>u).val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
227-
return u
211+
if not isinstance(t, arb_mat):
212+
s, t = arb_mat_coerce_operands(s, t)
213+
if s is NotImplemented:
214+
return s
215+
return s + t
216+
217+
m = (<arb_mat>s).nrows()
218+
n = (<arb_mat>s).ncols()
219+
if m != (<arb_mat>t).nrows() or n != (<arb_mat>t).ncols():
220+
raise ValueError("incompatible shapes for matrix addition")
221+
u = arb_mat.__new__(arb_mat)
222+
arb_mat_init((<arb_mat>u).val, m, n)
223+
arb_mat_add((<arb_mat>u).val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
224+
return u
225+
226+
def __radd__(s, t):
228227
s, t = arb_mat_coerce_operands(s, t)
229228
if s is NotImplemented:
230229
return s
231-
return s + t
230+
return t + s
232231

233232
def __sub__(s, t):
234233
cdef long m, n
235-
if type(s) is type(t):
236-
m = (<arb_mat>s).nrows()
237-
n = (<arb_mat>s).ncols()
238-
if m != (<arb_mat>t).nrows() or n != (<arb_mat>t).ncols():
239-
raise ValueError("incompatible shapes for matrix addition")
240-
u = arb_mat.__new__(arb_mat)
241-
arb_mat_init((<arb_mat>u).val, m, n)
242-
arb_mat_sub((<arb_mat>u).val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
243-
return u
234+
if not isinstance(t, arb_mat):
235+
s, t = arb_mat_coerce_operands(s, t)
236+
if s is NotImplemented:
237+
return s
238+
return s - t
239+
240+
m = (<arb_mat>s).nrows()
241+
n = (<arb_mat>s).ncols()
242+
if m != (<arb_mat>t).nrows() or n != (<arb_mat>t).ncols():
243+
raise ValueError("incompatible shapes for matrix addition")
244+
u = arb_mat.__new__(arb_mat)
245+
arb_mat_init((<arb_mat>u).val, m, n)
246+
arb_mat_sub((<arb_mat>u).val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
247+
return u
248+
249+
def __rsub__(s, t):
244250
s, t = arb_mat_coerce_operands(s, t)
245251
if s is NotImplemented:
246252
return s
247-
return s - t
253+
return t - s
248254

249255
def _scalar_mul_(s, arb t):
250256
cdef arb_mat u
@@ -255,25 +261,31 @@ cdef class arb_mat(flint_mat):
255261

256262
def __mul__(s, t):
257263
cdef arb_mat u
258-
if type(s) is type(t):
259-
if arb_mat_ncols((<arb_mat>s).val) != arb_mat_nrows((<arb_mat>t).val):
260-
raise ValueError("incompatible shapes for matrix multiplication")
261-
u = arb_mat.__new__(arb_mat)
262-
arb_mat_init(u.val, arb_mat_nrows((<arb_mat>s).val), arb_mat_ncols((<arb_mat>t).val))
263-
arb_mat_mul(u.val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
264-
return u
265-
if typecheck(s, arb_mat):
264+
if not isinstance(t, arb_mat):
266265
c, d = arb_mat_coerce_scalar(s, t)
267266
if c is not NotImplemented:
268267
return c._scalar_mul_(d)
269-
else:
270-
d, c = arb_mat_coerce_scalar(t, s)
271-
if d is not NotImplemented:
272-
return d._scalar_mul_(c)
268+
s, t = arb_mat_coerce_operands(s, t)
269+
if s is NotImplemented:
270+
return s
271+
return s * t
272+
273+
if arb_mat_ncols((<arb_mat>s).val) != arb_mat_nrows((<arb_mat>t).val):
274+
raise ValueError("incompatible shapes for matrix multiplication")
275+
u = arb_mat.__new__(arb_mat)
276+
arb_mat_init(u.val, arb_mat_nrows((<arb_mat>s).val), arb_mat_ncols((<arb_mat>t).val))
277+
arb_mat_mul(u.val, (<arb_mat>s).val, (<arb_mat>t).val, getprec())
278+
return u
279+
280+
def __rmul__(s, t):
281+
cdef arb_mat u
282+
c, d = arb_mat_coerce_scalar(s, t)
283+
if c is not NotImplemented:
284+
return c._scalar_mul_(d)
273285
s, t = arb_mat_coerce_operands(s, t)
274286
if s is NotImplemented:
275287
return s
276-
return s * t
288+
return t * s
277289

278290
def _scalar_div_(s, arb t):
279291
cdef arb_mat u
@@ -282,28 +294,17 @@ cdef class arb_mat(flint_mat):
282294
arb_mat_scalar_div_arb(u.val, s.val, t.val, getprec())
283295
return u
284296

285-
@staticmethod
286-
def _div_(s, t):
287-
cdef arb_mat u
288-
if typecheck(s, arb_mat):
289-
s, t = arb_mat_coerce_scalar(s, t)
290-
if s is NotImplemented:
291-
return s
292-
return s._scalar_div_(t)
293-
return NotImplemented
294-
295297
def __truediv__(s, t):
296-
return arb_mat._div_(s, t)
297-
298-
def __div__(s, t):
299-
return arb_mat._div_(s, t)
298+
cdef arb_mat u
299+
s, t = arb_mat_coerce_scalar(s, t)
300+
if s is NotImplemented:
301+
return s
302+
return s._scalar_div_(t)
300303

301304
def __pow__(s, e, m):
302305
cdef arb_mat u
303306
cdef ulong exp
304307
cdef long n
305-
if not typecheck(s, arb_mat):
306-
return NotImplemented
307308
exp = e
308309
n = arb_mat_nrows((<arb_mat>s).val)
309310
if n != arb_mat_ncols((<arb_mat>s).val):

0 commit comments

Comments
 (0)