Skip to content

Commit 0afb92e

Browse files
committed
Fix reflected operators bug. Add test coverage for the rest of the arithmetic operators
1 parent fb27e46 commit 0afb92e

File tree

3 files changed

+44
-36
lines changed

3 files changed

+44
-36
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ._dtypes import int64 as af_int64
2020
from ._dtypes import supported_dtypes
2121
from ._dtypes import uint64 as af_uint64
22-
from ._utils import PointerSource, is_number, to_str
22+
from ._utils import PointerSource, to_str
2323

2424
ShapeType = tuple[int, ...]
2525
_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
@@ -286,25 +286,25 @@ def __radd__(self, other: Array, /) -> Array:
286286
"""
287287
Return other + self.
288288
"""
289-
return _process_c_function(self, other, backend.get().af_add)
289+
return _process_c_function(other, self, backend.get().af_add)
290290

291291
def __rsub__(self, other: Array, /) -> Array:
292292
"""
293293
Return other - self.
294294
"""
295-
return _process_c_function(self, other, backend.get().af_sub)
295+
return _process_c_function(other, self, backend.get().af_sub)
296296

297297
def __rmul__(self, other: Array, /) -> Array:
298298
"""
299299
Return other * self.
300300
"""
301-
return _process_c_function(self, other, backend.get().af_mul)
301+
return _process_c_function(other, self, backend.get().af_mul)
302302

303303
def __rtruediv__(self, other: Array, /) -> Array:
304304
"""
305305
Return other / self.
306306
"""
307-
return _process_c_function(self, other, backend.get().af_div)
307+
return _process_c_function(other, self, backend.get().af_div)
308308

309309
def __rfloordiv__(self, other: Array, /) -> Array:
310310
# TODO
@@ -314,13 +314,13 @@ def __rmod__(self, other: Array, /) -> Array:
314314
"""
315315
Return other / self.
316316
"""
317-
return _process_c_function(self, other, backend.get().af_mod)
317+
return _process_c_function(other, self, backend.get().af_mod)
318318

319319
def __rpow__(self, other: Array, /) -> Array:
320320
"""
321321
Return other ** self.
322322
"""
323-
return _process_c_function(self, other, backend.get().af_pow)
323+
return _process_c_function(other, self, backend.get().af_pow)
324324

325325
# Reflected Array Operators
326326

@@ -334,31 +334,31 @@ def __rand__(self, other: Array, /) -> Array:
334334
"""
335335
Return other & self.
336336
"""
337-
return _process_c_function(self, other, backend.get().af_bitand)
337+
return _process_c_function(other, self, backend.get().af_bitand)
338338

339339
def __ror__(self, other: Array, /) -> Array:
340340
"""
341341
Return other & self.
342342
"""
343-
return _process_c_function(self, other, backend.get().af_bitor)
343+
return _process_c_function(other, self, backend.get().af_bitor)
344344

345345
def __rxor__(self, other: Array, /) -> Array:
346346
"""
347347
Return other ^ self.
348348
"""
349-
return _process_c_function(self, other, backend.get().af_bitxor)
349+
return _process_c_function(other, self, backend.get().af_bitxor)
350350

351351
def __rlshift__(self, other: Array, /) -> Array:
352352
"""
353353
Return other << self.
354354
"""
355-
return _process_c_function(self, other, backend.get().af_bitshiftl)
355+
return _process_c_function(other, self, backend.get().af_bitshiftl)
356356

357357
def __rrshift__(self, other: Array, /) -> Array:
358358
"""
359359
Return other >> self.
360360
"""
361-
return _process_c_function(self, other, backend.get().af_bitshiftr)
361+
return _process_c_function(other, self, backend.get().af_bitshiftr)
362362

363363
# In-place Arithmetic Operators
364364

@@ -614,20 +614,32 @@ def _str_to_dtype(value: int) -> Dtype:
614614

615615

616616
def _process_c_function(
617-
target: Array, other: int | float | bool | complex | Array, c_function: Any) -> Array:
617+
lhs: int | float | bool | complex | Array, rhs: int | float | bool | complex | Array,
618+
c_function: Any) -> Array:
618619
out = Array()
619620

620-
# TODO discuss the difference between binary_func and binary_funcr
621-
# because implementation looks like exectly the same.
622-
# consider chaging to __iadd__ = __radd__ = __add__ interfce if no difference
623-
if isinstance(other, Array):
624-
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
625-
elif is_number(other):
626-
other_dtype = _implicit_dtype(other, target.dtype)
627-
other_array = _constant_array(other, CShape(*target.shape), other_dtype)
628-
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
621+
if isinstance(lhs, Array) and isinstance(rhs, Array):
622+
lhs_array = lhs.arr
623+
rhs_array = rhs.arr
624+
625+
elif isinstance(lhs, Array) and isinstance(rhs, int | float | bool | complex):
626+
rhs_dtype = _implicit_dtype(rhs, lhs.dtype)
627+
rhs_constant_array = _constant_array(rhs, CShape(*lhs.shape), rhs_dtype)
628+
629+
lhs_array = lhs.arr
630+
rhs_array = rhs_constant_array.arr
631+
632+
elif isinstance(lhs, int | float | bool | complex) and isinstance(rhs, Array):
633+
lhs_dtype = _implicit_dtype(lhs, rhs.dtype)
634+
lhs_constant_array = _constant_array(lhs, CShape(*rhs.shape), lhs_dtype)
635+
636+
lhs_array = lhs_constant_array.arr
637+
rhs_array = rhs.arr
638+
629639
else:
630-
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
640+
raise TypeError(f"{type(rhs)} is not supported and can not be passed to C binary function.")
641+
642+
safe_call(c_function(ctypes.pointer(out.arr), lhs_array, rhs_array, _bcast_var))
631643

632644
return out
633645

arrayfire/array_api/_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ctypes
22
import enum
3-
import numbers
43

54

65
class PointerSource(enum.Enum):
@@ -14,7 +13,3 @@ class PointerSource(enum.Enum):
1413

1514
def to_str(c_str: ctypes.c_char_p) -> str:
1615
return str(c_str.value.decode("utf-8")) # type: ignore[union-attr]
17-
18-
19-
def is_number(number: int | float | bool | complex) -> bool:
20-
return isinstance(number, numbers.Number)

arrayfire/array_api/tests/test_array_object.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ def setup_method(self, method: Any) -> None:
179179
self.tuple = (1, 2, 3)
180180
self.const_str = "15"
181181

182-
def teardown_method(self, method: Any) -> None:
183-
self.array = Array(self.list)
184-
185182
def test_add_int(self) -> None:
186183
res = self.array + self.const_int
187184
assert res[0].scalar() == 3
@@ -220,10 +217,10 @@ def test_add_inplace_and_reflected(self) -> None:
220217

221218
def test_add_raises_type_error(self) -> None:
222219
with pytest.raises(TypeError):
223-
Array([1, 2, 3]) + self.const_str # type: ignore[operator]
220+
self.array + self.const_str # type: ignore[operator]
224221

225222
with pytest.raises(TypeError):
226-
Array([1, 2, 3]) + self.tuple # type: ignore[operator]
223+
self.array + self.tuple # type: ignore[operator]
227224

228225
# Test __sub__, __isub__, __rsub__
229226

@@ -251,9 +248,13 @@ def test_sub_inplace_and_reflected(self) -> None:
251248
ires -= self.const_int
252249
rres = self.const_int - self.array # type: ignore[operator]
253250

254-
assert res[0].scalar() == ires[0].scalar() == rres[0].scalar() == -1
255-
assert res[1].scalar() == ires[1].scalar() == rres[1].scalar() == 0
256-
assert res[2].scalar() == ires[2].scalar() == rres[2].scalar() == 1
251+
assert res[0].scalar() == ires[0].scalar() == -1
252+
assert res[1].scalar() == ires[1].scalar() == 0
253+
assert res[2].scalar() == ires[2].scalar() == 1
254+
255+
assert rres[0].scalar() == 1
256+
assert rres[1].scalar() == 0
257+
assert rres[2].scalar() == -1
257258

258259
assert res.dtype == ires.dtype == rres.dtype
259260
assert res.ndim == ires.ndim == rres.ndim

0 commit comments

Comments
 (0)