19
19
from ._dtypes import int64 as af_int64
20
20
from ._dtypes import supported_dtypes
21
21
from ._dtypes import uint64 as af_uint64
22
- from ._utils import PointerSource , is_number , to_str
22
+ from ._utils import PointerSource , to_str
23
23
24
24
ShapeType = tuple [int , ...]
25
25
_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
@@ -286,25 +286,25 @@ def __radd__(self, other: Array, /) -> Array:
286
286
"""
287
287
Return other + self.
288
288
"""
289
- return _process_c_function (self , other , backend .get ().af_add )
289
+ return _process_c_function (other , self , backend .get ().af_add )
290
290
291
291
def __rsub__ (self , other : Array , / ) -> Array :
292
292
"""
293
293
Return other - self.
294
294
"""
295
- return _process_c_function (self , other , backend .get ().af_sub )
295
+ return _process_c_function (other , self , backend .get ().af_sub )
296
296
297
297
def __rmul__ (self , other : Array , / ) -> Array :
298
298
"""
299
299
Return other * self.
300
300
"""
301
- return _process_c_function (self , other , backend .get ().af_mul )
301
+ return _process_c_function (other , self , backend .get ().af_mul )
302
302
303
303
def __rtruediv__ (self , other : Array , / ) -> Array :
304
304
"""
305
305
Return other / self.
306
306
"""
307
- return _process_c_function (self , other , backend .get ().af_div )
307
+ return _process_c_function (other , self , backend .get ().af_div )
308
308
309
309
def __rfloordiv__ (self , other : Array , / ) -> Array :
310
310
# TODO
@@ -314,13 +314,13 @@ def __rmod__(self, other: Array, /) -> Array:
314
314
"""
315
315
Return other / self.
316
316
"""
317
- return _process_c_function (self , other , backend .get ().af_mod )
317
+ return _process_c_function (other , self , backend .get ().af_mod )
318
318
319
319
def __rpow__ (self , other : Array , / ) -> Array :
320
320
"""
321
321
Return other ** self.
322
322
"""
323
- return _process_c_function (self , other , backend .get ().af_pow )
323
+ return _process_c_function (other , self , backend .get ().af_pow )
324
324
325
325
# Reflected Array Operators
326
326
@@ -334,31 +334,31 @@ def __rand__(self, other: Array, /) -> Array:
334
334
"""
335
335
Return other & self.
336
336
"""
337
- return _process_c_function (self , other , backend .get ().af_bitand )
337
+ return _process_c_function (other , self , backend .get ().af_bitand )
338
338
339
339
def __ror__ (self , other : Array , / ) -> Array :
340
340
"""
341
341
Return other & self.
342
342
"""
343
- return _process_c_function (self , other , backend .get ().af_bitor )
343
+ return _process_c_function (other , self , backend .get ().af_bitor )
344
344
345
345
def __rxor__ (self , other : Array , / ) -> Array :
346
346
"""
347
347
Return other ^ self.
348
348
"""
349
- return _process_c_function (self , other , backend .get ().af_bitxor )
349
+ return _process_c_function (other , self , backend .get ().af_bitxor )
350
350
351
351
def __rlshift__ (self , other : Array , / ) -> Array :
352
352
"""
353
353
Return other << self.
354
354
"""
355
- return _process_c_function (self , other , backend .get ().af_bitshiftl )
355
+ return _process_c_function (other , self , backend .get ().af_bitshiftl )
356
356
357
357
def __rrshift__ (self , other : Array , / ) -> Array :
358
358
"""
359
359
Return other >> self.
360
360
"""
361
- return _process_c_function (self , other , backend .get ().af_bitshiftr )
361
+ return _process_c_function (other , self , backend .get ().af_bitshiftr )
362
362
363
363
# In-place Arithmetic Operators
364
364
@@ -614,20 +614,32 @@ def _str_to_dtype(value: int) -> Dtype:
614
614
615
615
616
616
def _process_c_function (
617
- target : Array , other : int | float | bool | complex | Array , c_function : Any ) -> Array :
617
+ lhs : int | float | bool | complex | Array , rhs : int | float | bool | complex | Array ,
618
+ c_function : Any ) -> Array :
618
619
out = Array ()
619
620
620
- # TODO discuss the difference between binary_func and binary_funcr
621
- # because implementation looks like exectly the same.
622
- # consider chaging to __iadd__ = __radd__ = __add__ interfce if no difference
623
- if isinstance (other , Array ):
624
- safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
625
- elif is_number (other ):
626
- other_dtype = _implicit_dtype (other , target .dtype )
627
- other_array = _constant_array (other , CShape (* target .shape ), other_dtype )
628
- safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
621
+ if isinstance (lhs , Array ) and isinstance (rhs , Array ):
622
+ lhs_array = lhs .arr
623
+ rhs_array = rhs .arr
624
+
625
+ elif isinstance (lhs , Array ) and isinstance (rhs , int | float | bool | complex ):
626
+ rhs_dtype = _implicit_dtype (rhs , lhs .dtype )
627
+ rhs_constant_array = _constant_array (rhs , CShape (* lhs .shape ), rhs_dtype )
628
+
629
+ lhs_array = lhs .arr
630
+ rhs_array = rhs_constant_array .arr
631
+
632
+ elif isinstance (lhs , int | float | bool | complex ) and isinstance (rhs , Array ):
633
+ lhs_dtype = _implicit_dtype (lhs , rhs .dtype )
634
+ lhs_constant_array = _constant_array (lhs , CShape (* rhs .shape ), lhs_dtype )
635
+
636
+ lhs_array = lhs_constant_array .arr
637
+ rhs_array = rhs .arr
638
+
629
639
else :
630
- raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
640
+ raise TypeError (f"{ type (rhs )} is not supported and can not be passed to C binary function." )
641
+
642
+ safe_call (c_function (ctypes .pointer (out .arr ), lhs_array , rhs_array , _bcast_var ))
631
643
632
644
return out
633
645
0 commit comments