Skip to content

Commit cdb7a92

Browse files
committed
Added to_list and to_ctypes_array
1 parent 4187b27 commit cdb7a92

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,20 @@ def __init__(
124124
ctypes.pointer(_cshape.c_array), ctypes.pointer(strides_cshape), dtype.c_api_value,
125125
pointer_source.value))
126126

127-
def __str__(self) -> str: # FIXME
127+
def __str__(self) -> str:
128+
# TODO change the look of array str. E.g., like np.array
128129
if not _in_display_dims_limit(self.shape):
129130
return _metadata_string(self.dtype, self.shape)
130131

131132
return _metadata_string(self.dtype) + _array_as_str(self)
132133

133-
def __repr__(self) -> str: # FIXME
134+
def __repr__(self) -> str:
134135
# return _metadata_string(self.dtype, self.shape)
135136
# TODO change the look of array representation. E.g., like np.array
136137
return _array_as_str(self)
137138

138139
def __len__(self) -> int:
139-
return self.shape[0] if self.shape else 0 # type: ignore[return-value]
140+
return self.shape[0] if self.shape else 0
140141

141142
# Arithmetic Operators
142143

@@ -475,17 +476,17 @@ def T(self) -> Array:
475476
raise NotImplementedError
476477

477478
@property
478-
def size(self) -> None | int:
479+
def size(self) -> int:
479480
# NOTE previously - elements()
480481
out = c_dim_t(0)
481482
safe_call(backend.get().af_get_elements(ctypes.pointer(out), self.arr))
482483
return out.value
483484

484485
@property
485486
def ndim(self) -> int:
486-
nd = ctypes.c_uint(0)
487-
safe_call(backend.get().af_get_numdims(ctypes.pointer(nd), self.arr))
488-
return nd.value
487+
out = ctypes.c_uint(0)
488+
safe_call(backend.get().af_get_numdims(ctypes.pointer(out), self.arr))
489+
return out.value
489490

490491
@property
491492
def shape(self) -> ShapeType:
@@ -510,6 +511,62 @@ def scalar(self) -> int | float | bool | complex:
510511
safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
511512
return out.value # type: ignore[no-any-return] # FIXME
512513

514+
def is_empty(self) -> bool:
515+
"""
516+
Check if the array is empty i.e. it has no elements.
517+
"""
518+
out = ctypes.c_bool()
519+
safe_call(backend.get().af_is_empty(ctypes.pointer(out), self.arr))
520+
return out.value
521+
522+
def to_list(self, row_major: bool = False) -> list: # FIXME return typings
523+
if self.is_empty():
524+
return []
525+
526+
array = _reorder(self) if row_major else self
527+
ctypes_array = _get_ctypes_array(array)
528+
529+
if array.ndim == 1:
530+
return list(ctypes_array)
531+
532+
out = []
533+
for i in range(array.size):
534+
idx = i
535+
sub_list = []
536+
for j in range(array.ndim):
537+
div = array.shape[j]
538+
sub_list.append(idx % div)
539+
idx //= div
540+
out.append(ctypes_array[sub_list[::-1]]) # type: ignore[call-overload] # FIXME
541+
return out
542+
543+
def to_ctype_array(self, row_major: bool = False) -> ctypes.Array:
544+
if self.is_empty():
545+
raise RuntimeError("Can not convert an empty array to ctype.")
546+
547+
array = _reorder(self) if row_major else self
548+
return _get_ctypes_array(array)
549+
550+
551+
def _get_ctypes_array(array: Array) -> ctypes.Array:
552+
c_shape = array.dtype.c_type * array.size
553+
ctypes_array = c_shape()
554+
safe_call(backend.get().af_get_data_ptr(ctypes.pointer(ctypes_array), array.arr))
555+
return ctypes_array
556+
557+
558+
def _reorder(array: Array) -> Array:
559+
"""
560+
Returns a reordered array to help interoperate with row major formats.
561+
"""
562+
if array.ndim == 1:
563+
return array
564+
565+
out = Array()
566+
c_shape = CShape(*(tuple(reversed(range(array.ndim))) + tuple(range(array.ndim, 4))))
567+
safe_call(backend.get().af_reorder(ctypes.pointer(out.arr), array.arr, *c_shape))
568+
return out
569+
513570

514571
def _array_as_str(array: Array) -> str:
515572
arr_str = ctypes.c_char_p(0)

arrayfire/array_api/tests/test_array_object.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from arrayfire.array_api import Array, float32, int16
44
from arrayfire.array_api._dtypes import supported_dtypes
55

6+
# TODO change separated methods with setup and teardown to avoid code duplication
7+
68

79
def test_empty_array() -> None:
810
array = Array()
@@ -105,7 +107,13 @@ def test_array_getitem() -> None:
105107
# TODO add more tests for different dtypes
106108

107109

108-
def test_array_sum() -> None:
110+
def test_array_to_list() -> None:
111+
# TODO add test of to_ctypes_array
112+
assert Array([1, 2, 3]).to_list() == [1, 2, 3]
113+
assert Array().to_list() == []
114+
115+
116+
def test_array_add() -> None:
109117
array = Array([1, 2, 3])
110118
res = array + 1
111119
assert res[0].scalar() == 2
@@ -123,6 +131,11 @@ def test_array_sum() -> None:
123131
assert res[2].scalar() == 12
124132

125133

134+
def test_array_add_raises_type_error() -> None:
135+
with pytest.raises(TypeError):
136+
Array([1, 2, 3]) + "15" # type: ignore[operator]
137+
138+
126139
def test_array_sub() -> None:
127140
array = Array([1, 2, 3])
128141
res = array - 1

0 commit comments

Comments
 (0)