Skip to content

Commit 8cef774

Browse files
committed
Fix array init bug. Add __getitem__. Change pytest for active debug mode
1 parent f0f57e8 commit 8cef774

File tree

3 files changed

+66
-18
lines changed

3 files changed

+66
-18
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from dataclasses import dataclass
66
from typing import Any
77

8-
from arrayfire import backend, safe_call # TODO refactoring
9-
from arrayfire.array import _in_display_dims_limit # TODO refactoring
8+
from arrayfire import backend, safe_call # TODO refactor
9+
from arrayfire.algorithm import count # TODO refactor
10+
from arrayfire.array import _get_indices, _in_display_dims_limit # TODO refactor
1011

1112
from ._dtypes import CShape, Dtype
1213
from ._dtypes import bool as af_bool
@@ -37,15 +38,15 @@ class Array:
3738
# arrayfire's __radd__() instead of numpy's __add__()
3839
__array_priority__ = 30
3940

40-
# Initialisation
41-
arr = ctypes.c_void_p(0)
42-
4341
def __init__(
4442
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None, dtype: None | Dtype = None,
4543
pointer_source: PointerSource = PointerSource.host, shape: None | ShapeType = None,
4644
offset: None | ctypes._SimpleCData[int] = None, strides: None | ShapeType = None) -> None:
4745
_no_initial_dtype = False # HACK, FIXME
4846

47+
# Initialise array object
48+
self.arr = ctypes.c_void_p(0)
49+
4950
if isinstance(dtype, str):
5051
dtype = _str_to_dtype(dtype)
5152

@@ -127,7 +128,7 @@ def __str__(self) -> str: # FIXME
127128
if not _in_display_dims_limit(self.shape):
128129
return _metadata_string(self.dtype, self.shape)
129130

130-
return _metadata_string(self.dtype) + self._as_str()
131+
return _metadata_string(self.dtype) + _array_as_str(self)
131132

132133
def __repr__(self) -> str: # FIXME
133134
return _metadata_string(self.dtype, self.shape)
@@ -173,6 +174,7 @@ def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
173174
return _process_c_function(self, other, backend.get().af_div)
174175

175176
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
177+
# TODO
176178
return NotImplemented
177179

178180
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
@@ -187,6 +189,25 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
187189
"""
188190
return _process_c_function(self, other, backend.get().af_pow)
189191

192+
def __matmul__(self, other: Array, /) -> Array:
193+
# TODO
194+
return NotImplemented
195+
196+
def __getitem__(self, key: int | slice | tuple[int | slice] | Array, /) -> Array:
197+
# TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
198+
# TODO: refactor
199+
out = Array()
200+
ndims = self.ndim
201+
202+
if isinstance(key, Array) and key == af_bool.c_api_value:
203+
ndims = 1
204+
if count(key) == 0:
205+
return out
206+
207+
safe_call(backend.get().af_index_gen(
208+
ctypes.pointer(out.arr), self.arr, c_dim_t(ndims), _get_indices(key).pointer))
209+
return out
210+
190211
@property
191212
def dtype(self) -> Dtype:
192213
out = ctypes.c_int()
@@ -234,13 +255,23 @@ def shape(self) -> ShapeType:
234255
ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
235256
return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
236257

237-
def _as_str(self) -> str:
238-
arr_str = ctypes.c_char_p(0)
239-
# FIXME add description to passed arguments
240-
safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", self.arr, 4, True))
241-
py_str = to_str(arr_str)
242-
safe_call(backend.get().af_free_host(arr_str))
243-
return py_str
258+
def scalar(self) -> int | float | bool | complex:
259+
"""
260+
Return the first element of the array
261+
"""
262+
# BUG seg fault on empty array
263+
out = self.dtype.c_type()
264+
safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
265+
return out.value # type: ignore[no-any-return] # FIXME
266+
267+
268+
def _array_as_str(array: Array) -> str:
269+
arr_str = ctypes.c_char_p(0)
270+
# FIXME add description to passed arguments
271+
safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", array.arr, 4, True))
272+
py_str = to_str(arr_str)
273+
safe_call(backend.get().af_free_host(arr_str))
274+
return py_str
244275

245276

246277
def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
@@ -283,9 +314,8 @@ def _process_c_function(
283314
if isinstance(other, Array):
284315
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
285316
elif is_number(other):
286-
target_c_shape = CShape(*target.shape)
287317
other_dtype = _implicit_dtype(other, target.dtype)
288-
other_array = _constant_array(other, target_c_shape, other_dtype)
318+
other_array = _constant_array(other, CShape(*target.shape), other_dtype)
289319
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
290320
else:
291321
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
@@ -326,7 +356,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
326356

327357
safe_call(backend.get().af_constant_complex(
328358
ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
329-
ctypes.pointer(shape.c_array), dtype))
359+
ctypes.pointer(shape.c_array), dtype.c_api_value))
330360
elif dtype == af_int64:
331361
safe_call(backend.get().af_constant_long(
332362
ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
@@ -335,6 +365,6 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
335365
ctypes.pointer(out.arr), ctypes.c_ulonglong(value.real), 4, ctypes.pointer(shape.c_array)))
336366
else:
337367
safe_call(backend.get().af_constant(
338-
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype))
368+
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype.c_api_value))
339369

340370
return out

arrayfire/array_api/pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[pytest]
2-
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --isort ./arrayfire/array_api
2+
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --isort -s ./arrayfire/array_api
33
console_output_style = classic
44
markers = mypy

arrayfire/array_api/tests/test_array_object.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,21 @@ def test_array_from_unsupported_type() -> None:
9393

9494
with pytest.raises(TypeError):
9595
Array({1: 2, 3: 4}) # type: ignore[arg-type]
96+
97+
98+
def test_array_getitem() -> None:
99+
array = Array([1, 2, 3, 4, 5])
100+
101+
int_item = array[2]
102+
assert array.dtype == int_item.dtype
103+
assert int_item.scalar() == 3
104+
105+
# TODO add more tests for different dtypes
106+
107+
108+
# def test_array_sum() -> None: # BUG no element-wise adding
109+
# array = Array([1, 2, 3])
110+
# res = array + 1
111+
# assert res.scalar() == 2
112+
# assert res.scalar() == 3
113+
# assert res.scalar() == 4

0 commit comments

Comments
 (0)