Skip to content

Commit fb27e46

Browse files
committed
Change tests and found bug with reflected operators
1 parent 769c16c commit fb27e46

File tree

2 files changed

+297
-122
lines changed

2 files changed

+297
-122
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Array:
3636
# Setting to such a high value should make sure that arrayfire has priority over
3737
# other classes, ensuring that e.g. numpy.float32(1)*arrayfire.randu(3) is handled by
3838
# arrayfire's __radd__() instead of numpy's __add__()
39-
__array_priority__ = 30
39+
__array_priority__ = 30 # TODO discuss its purpose
4040

4141
def __init__(
4242
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None,
@@ -286,25 +286,25 @@ def __radd__(self, other: Array, /) -> Array:
286286
"""
287287
Return other + self.
288288
"""
289-
return _process_c_function(other, self, backend.get().af_add)
289+
return _process_c_function(self, other, backend.get().af_add)
290290

291291
def __rsub__(self, other: Array, /) -> Array:
292292
"""
293293
Return other - self.
294294
"""
295-
return _process_c_function(other, self, backend.get().af_sub)
295+
return _process_c_function(self, other, backend.get().af_sub)
296296

297297
def __rmul__(self, other: Array, /) -> Array:
298298
"""
299299
Return other * self.
300300
"""
301-
return _process_c_function(other, self, backend.get().af_mul)
301+
return _process_c_function(self, other, backend.get().af_mul)
302302

303303
def __rtruediv__(self, other: Array, /) -> Array:
304304
"""
305305
Return other / self.
306306
"""
307-
return _process_c_function(other, self, backend.get().af_div)
307+
return _process_c_function(self, other, backend.get().af_div)
308308

309309
def __rfloordiv__(self, other: Array, /) -> Array:
310310
# TODO
@@ -314,13 +314,13 @@ def __rmod__(self, other: Array, /) -> Array:
314314
"""
315315
Return other / self.
316316
"""
317-
return _process_c_function(other, self, backend.get().af_mod)
317+
return _process_c_function(self, other, backend.get().af_mod)
318318

319319
def __rpow__(self, other: Array, /) -> Array:
320320
"""
321321
Return other ** self.
322322
"""
323-
return _process_c_function(other, self, backend.get().af_pow)
323+
return _process_c_function(self, other, backend.get().af_pow)
324324

325325
# Reflected Array Operators
326326

@@ -334,31 +334,31 @@ def __rand__(self, other: Array, /) -> Array:
334334
"""
335335
Return other & self.
336336
"""
337-
return _process_c_function(other, self, backend.get().af_bitand)
337+
return _process_c_function(self, other, backend.get().af_bitand)
338338

339339
def __ror__(self, other: Array, /) -> Array:
340340
"""
341341
Return other & self.
342342
"""
343-
return _process_c_function(other, self, backend.get().af_bitor)
343+
return _process_c_function(self, other, backend.get().af_bitor)
344344

345345
def __rxor__(self, other: Array, /) -> Array:
346346
"""
347347
Return other ^ self.
348348
"""
349-
return _process_c_function(other, self, backend.get().af_bitxor)
349+
return _process_c_function(self, other, backend.get().af_bitxor)
350350

351351
def __rlshift__(self, other: Array, /) -> Array:
352352
"""
353353
Return other << self.
354354
"""
355-
return _process_c_function(other, self, backend.get().af_bitshiftl)
355+
return _process_c_function(self, other, backend.get().af_bitshiftl)
356356

357357
def __rrshift__(self, other: Array, /) -> Array:
358358
"""
359359
Return other >> self.
360360
"""
361-
return _process_c_function(other, self, backend.get().af_bitshiftr)
361+
return _process_c_function(self, other, backend.get().af_bitshiftr)
362362

363363
# In-place Arithmetic Operators
364364

@@ -617,6 +617,9 @@ def _process_c_function(
617617
target: Array, other: int | float | bool | complex | Array, c_function: Any) -> Array:
618618
out = Array()
619619

620+
# TODO discuss the difference between binary_func and binary_funcr
621+
# because implementation looks like exectly the same.
622+
# consider chaging to __iadd__ = __radd__ = __add__ interfce if no difference
620623
if isinstance(other, Array):
621624
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
622625
elif is_number(other):

0 commit comments

Comments
 (0)