Skip to content

Commit f0f57e8

Browse files
committed
Add arithmetic operators w/o tests
1 parent c13a59f commit f0f57e8

File tree

2 files changed

+114
-53
lines changed

2 files changed

+114
-53
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 113 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,25 @@
33
import array as py_array
44
import ctypes
55
from dataclasses import dataclass
6+
from typing import Any
67

78
from arrayfire import backend, safe_call # TODO refactoring
89
from arrayfire.array import _in_display_dims_limit # TODO refactoring
910

10-
from ._dtypes import CShape, Dtype, c_dim_t, float32, supported_dtypes
11-
from ._utils import Device, PointerSource, to_str
11+
from ._dtypes import CShape, Dtype
12+
from ._dtypes import bool as af_bool
13+
from ._dtypes import c_dim_t
14+
from ._dtypes import complex64 as af_complex64
15+
from ._dtypes import complex128 as af_complex128
16+
from ._dtypes import float32 as af_float32
17+
from ._dtypes import float64 as af_float64
18+
from ._dtypes import int64 as af_int64
19+
from ._dtypes import supported_dtypes
20+
from ._dtypes import uint64 as af_uint64
21+
from ._utils import PointerSource, is_number, to_str
1222

1323
ShapeType = tuple[int, ...]
24+
_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
1425

1526

1627
@dataclass
@@ -40,7 +51,7 @@ def __init__(
4051

4152
if dtype is None:
4253
_no_initial_dtype = True
43-
dtype = float32
54+
dtype = af_float32
4455

4556
if x is None:
4657
if not shape: # shape is None or empty tuple
@@ -134,15 +145,47 @@ def __neg__(self) -> Array:
134145
"""
135146
Return -self
136147
"""
137-
# return 0 - self
138-
raise NotImplementedError
148+
return 0 - self
139149

140150
def __add__(self, other: int | float | Array, /) -> Array:
151+
# TODO discuss either we need to support complex and bool as other input type
141152
"""
142153
Return self + other.
143154
"""
144-
# return _binary_func(self, other, backend.get().af_add) # TODO
145-
raise NotImplementedError
155+
return _process_c_function(self, other, backend.get().af_add)
156+
157+
def __sub__(self, other: int | float | bool | complex | Array, /) -> Array:
158+
"""
159+
Return self - other.
160+
"""
161+
return _process_c_function(self, other, backend.get().af_sub)
162+
163+
def __mul__(self, other: int | float | bool | complex | Array, /) -> Array:
164+
"""
165+
Return self * other.
166+
"""
167+
return _process_c_function(self, other, backend.get().af_mul)
168+
169+
def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
170+
"""
171+
Return self / other.
172+
"""
173+
return _process_c_function(self, other, backend.get().af_div)
174+
175+
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
176+
return NotImplemented
177+
178+
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
179+
"""
180+
Return self % other.
181+
"""
182+
return _process_c_function(self, other, backend.get().af_mod)
183+
184+
def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
185+
"""
186+
Return self ** other.
187+
"""
188+
return _process_c_function(self, other, backend.get().af_pow)
146189

147190
@property
148191
def dtype(self) -> Dtype:
@@ -151,7 +194,7 @@ def dtype(self) -> Dtype:
151194
return _c_api_value_to_dtype(out.value)
152195

153196
@property
154-
def device(self) -> Device:
197+
def device(self) -> Any:
155198
raise NotImplementedError
156199

157200
@property
@@ -232,41 +275,66 @@ def _str_to_dtype(value: int) -> Dtype:
232275

233276
raise TypeError("There is no supported dtype that matches passed dtype typecode.")
234277

235-
# TODO
236-
# def _binary_func(lhs: int | float | Array, rhs: int | float | Array, c_func: Any) -> Array: # TODO replace Any
237-
# out = Array()
238-
# other = rhs
239-
240-
# if is_number(rhs):
241-
# ldims = _fill_dim4_tuple(lhs.shape)
242-
# rty = implicit_dtype(rhs, lhs.type())
243-
# other = Array()
244-
# other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
245-
# elif not isinstance(rhs, Array):
246-
# raise TypeError("Invalid parameter to binary function")
247-
248-
# safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
249-
250-
# return out
251-
252-
253-
# TODO replace candidate below
254-
# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
255-
# assert(isinstance(dims, tuple))
256-
257-
# if (default is not None):
258-
# assert(is_number(default))
259-
260-
# out = [default]*4
261-
262-
# for i, dim in enumerate(dims):
263-
# out[i] = dim
264-
265-
# return tuple(out)
266-
267-
# def _fill_dim4_tuple(shape: ShapeType) -> tuple[int, ...]:
268-
# out = tuple([1 if value is None else value for value in shape])
269-
# if len(out) == 4:
270-
# return out
271278

272-
# return out + (1,)*(4-len(out))
279+
def _process_c_function(
280+
target: Array, other: int | float | bool | complex | Array, c_function: Any) -> Array:
281+
out = Array()
282+
283+
if isinstance(other, Array):
284+
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
285+
elif is_number(other):
286+
target_c_shape = CShape(*target.shape)
287+
other_dtype = _implicit_dtype(other, target.dtype)
288+
other_array = _constant_array(other, target_c_shape, other_dtype)
289+
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
290+
else:
291+
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
292+
293+
return out
294+
295+
296+
def _implicit_dtype(value: int | float | bool | complex, array_dtype: Dtype) -> Dtype:
297+
if isinstance(value, bool):
298+
value_dtype = af_bool
299+
if isinstance(value, int):
300+
value_dtype = af_int64
301+
elif isinstance(value, float):
302+
value_dtype = af_float64
303+
elif isinstance(value, complex):
304+
value_dtype = af_complex128
305+
else:
306+
raise TypeError(f"{type(value)} is not supported and can not be converted to af.Dtype.")
307+
308+
if not (array_dtype == af_float32 or array_dtype == af_complex64):
309+
return value_dtype
310+
311+
if value_dtype == af_float64:
312+
return af_float32
313+
314+
if value_dtype == af_complex128:
315+
return af_complex64
316+
317+
return value_dtype
318+
319+
320+
def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: Dtype) -> Array:
321+
out = Array()
322+
323+
if isinstance(value, complex):
324+
if dtype != af_complex64 and dtype != af_complex128:
325+
dtype = af_complex64
326+
327+
safe_call(backend.get().af_constant_complex(
328+
ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
329+
ctypes.pointer(shape.c_array), dtype))
330+
elif dtype == af_int64:
331+
safe_call(backend.get().af_constant_long(
332+
ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
333+
elif dtype == af_uint64:
334+
safe_call(backend.get().af_constant_ulong(
335+
ctypes.pointer(out.arr), ctypes.c_ulonglong(value.real), 4, ctypes.pointer(shape.c_array)))
336+
else:
337+
safe_call(backend.get().af_constant(
338+
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype))
339+
340+
return out

arrayfire/array_api/_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
import ctypes
22
import enum
33
import numbers
4-
from typing import Any
5-
6-
7-
class Device(enum.Enum):
8-
# HACK. TODO make it real
9-
cpu = "cpu"
10-
gpu = "gpu"
114

125

136
class PointerSource(enum.Enum):
@@ -23,5 +16,5 @@ def to_str(c_str: ctypes.c_char_p) -> str:
2316
return str(c_str.value.decode("utf-8")) # type: ignore[union-attr]
2417

2518

26-
def is_number(number: Any) -> bool:
19+
def is_number(number: int | float | bool | complex) -> bool:
2720
return isinstance(number, numbers.Number)

0 commit comments

Comments
 (0)