Skip to content

Commit b14aa91

Browse files
committed
Replace dim4 with CShape
1 parent 75c1d43 commit b14aa91

File tree

5 files changed

+83
-71
lines changed

5 files changed

+83
-71
lines changed

arrayfire/array_api/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@
66
"complex64", "complex128", "bool"]
77

88
from ._array_object import Array
9-
from ._dtypes import (
10-
bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64)
9+
from ._dtypes import bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64

arrayfire/array_api/_array_object.py

Lines changed: 25 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import array as py_array
44
import ctypes
5-
import math
65
from dataclasses import dataclass
76

87
from arrayfire import backend, safe_call # TODO refactoring
98
from arrayfire.array import _in_display_dims_limit # TODO refactoring
109

11-
from ._dtypes import Dtype, c_dim_t, float32, supported_dtypes
10+
from ._dtypes import CShape, Dtype, c_dim_t, float32, supported_dtypes
1211
from ._utils import Device, PointerSource, to_str
1312

1413
ShapeType = tuple[None | int, ...]
@@ -28,7 +27,6 @@ class Array:
2827
__array_priority__ = 30
2928

3029
# Initialisation
31-
_array_buffer = _ArrayBuffer()
3230
arr = ctypes.c_void_p(0)
3331

3432
def __init__(
@@ -46,12 +44,12 @@ def __init__(
4644
if x is None:
4745
if not shape: # shape is None or empty tuple
4846
safe_call(backend.get().af_create_handle(
49-
ctypes.pointer(self.arr), 0, ctypes.pointer(dim4()), dtype.c_api_value))
47+
ctypes.pointer(self.arr), 0, ctypes.pointer(CShape().c_array), dtype.c_api_value))
5048
return
5149

5250
# NOTE: applies inplace changes for self.arr
5351
safe_call(backend.get().af_create_handle(
54-
ctypes.pointer(self.arr), len(shape), ctypes.pointer(dim4(*shape)), dtype.c_api_value))
52+
ctypes.pointer(self.arr), len(shape), ctypes.pointer(CShape(*shape).c_array), dtype.c_api_value))
5553
return
5654

5755
if isinstance(x, Array):
@@ -61,19 +59,16 @@ def __init__(
6159
if isinstance(x, py_array.array):
6260
_type_char = x.typecode
6361
_array_buffer = _ArrayBuffer(*x.buffer_info())
64-
numdims, idims = _get_info(shape, _array_buffer.length)
6562

6663
elif isinstance(x, list):
6764
_array = py_array.array("f", x) # BUG [True, False] -> dtype: f32 # TODO add int and float
6865
_type_char = _array.typecode
6966
_array_buffer = _ArrayBuffer(*_array.buffer_info())
70-
numdims, idims = _get_info(shape, _array_buffer.length)
7167

7268
elif isinstance(x, int) or isinstance(x, ctypes.c_void_p): # TODO
7369
_array_buffer = _ArrayBuffer(x if not isinstance(x, ctypes.c_void_p) else x.value)
74-
numdims, idims = _get_info(shape, _array_buffer.length)
7570

76-
if not math.prod(idims):
71+
if not shape:
7772
raise RuntimeError("Expected to receive the initial shape due to the x being a data pointer.")
7873

7974
if _no_initial_dtype:
@@ -84,34 +79,37 @@ def __init__(
8479
else:
8580
raise TypeError("Passed object x is an object of unsupported class.")
8681

82+
_cshape = _get_cshape(shape, _array_buffer.length)
83+
8784
if not _no_initial_dtype and dtype.typecode != _type_char:
8885
raise TypeError("Can not create array of requested type from input data type")
8986

9087
if not (offset or strides):
9188
if pointer_source == PointerSource.host:
9289
safe_call(backend.get().af_create_array(
93-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), numdims,
94-
ctypes.pointer(dim4(*idims)), dtype.c_api_value))
90+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), _cshape.original_shape,
91+
ctypes.pointer(_cshape.c_array), dtype.c_api_value))
9592
return
9693

9794
safe_call(backend.get().af_device_array(
98-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), numdims,
99-
ctypes.pointer(dim4(*idims)), dtype.c_api_value))
95+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), _cshape.original_shape,
96+
ctypes.pointer(_cshape.c_array), dtype.c_api_value))
10097
return
10198

102-
if offset is None: # TODO
99+
if offset is None:
103100
offset = c_dim_t(0)
104101

105-
if strides is None: # TODO
106-
strides = (1, idims[0], idims[0]*idims[1], idims[0]*idims[1]*idims[2])
102+
if strides is None:
103+
strides = (1, _cshape[0], _cshape[0]*_cshape[1], _cshape[0]*_cshape[1]*_cshape[2])
107104

108105
if len(strides) < 4:
109106
strides += (strides[-1], ) * (4 - len(strides))
110-
strides_dim4 = dim4(*strides)
107+
strides_cshape = CShape(*strides).c_array
111108

112109
safe_call(backend.get().af_create_strided_array(
113-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), offset, numdims,
114-
ctypes.pointer(dim4(*idims)), ctypes.pointer(strides_dim4), dtype.c_api_value, pointer_source.value))
110+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), offset, _cshape.original_shape,
111+
ctypes.pointer(_cshape.c_array), ctypes.pointer(strides_cshape), dtype.c_api_value,
112+
pointer_source.value))
115113

116114
def __str__(self) -> str: # FIXME
117115
if not _in_display_dims_limit(self.shape):
@@ -126,7 +124,7 @@ def __len__(self) -> int:
126124
return self.shape[0] if self.shape else 0 # type: ignore[return-value]
127125

128126
def __pos__(self) -> Array:
129-
"""y
127+
"""
130128
Return +self
131129
"""
132130
return self
@@ -190,8 +188,7 @@ def shape(self) -> ShapeType:
190188
d3 = c_dim_t(0)
191189
safe_call(backend.get().af_get_dims(
192190
ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
193-
dims = (d0.value, d1.value, d2.value, d3.value)
194-
return dims[:self.ndim] # FIXME An array dimension must be None if and only if a dimension is unknown
191+
return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
195192

196193
def _as_str(self) -> str:
197194
arr_str = ctypes.c_char_p(0)
@@ -201,30 +198,6 @@ def _as_str(self) -> str:
201198
safe_call(backend.get().af_free_host(arr_str))
202199
return py_str
203200

204-
# def _get_metadata_str(self, show_dims: bool = True) -> str:
205-
# return (
206-
# "arrayfire.Array()\n"
207-
# f"Type: {self.dtype.typename}\n"
208-
# f"Dims: {str(self._dims) if show_dims else ''}")
209-
210-
# @property
211-
# def dtype(self) -> ...:
212-
# dty = ctypes.c_int()
213-
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
214-
215-
# @safe_call
216-
# def backend()
217-
# ...
218-
219-
# @backend(safe=True)
220-
# def af_get_type(arr) -> ...:
221-
# dty = ctypes.c_int()
222-
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
223-
# return dty
224-
225-
# def new_dtype():
226-
# return af_get_type(self.arr)
227-
228201

229202
def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
230203
return (
@@ -233,20 +206,14 @@ def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
233206
f"Dims: {str(dims) if dims else ''}")
234207

235208

236-
def _get_info(shape: None | tuple[int], buffer_length: int) -> tuple[int, list[int]]:
237-
# TODO refactor
209+
def _get_cshape(shape: None | tuple[int], buffer_length: int) -> CShape:
238210
if shape:
239-
numdims = len(shape)
240-
idims = [1]*4
241-
for i in range(numdims):
242-
idims[i] = shape[i]
243-
elif (buffer_length != 0):
244-
idims = [buffer_length, 1, 1, 1]
245-
numdims = 1
246-
else:
247-
raise RuntimeError("Invalid size")
211+
return CShape(*shape)
212+
213+
if buffer_length != 0:
214+
return CShape(buffer_length)
248215

249-
return numdims, idims
216+
raise RuntimeError("Shape and buffer length are size invalid.")
250217

251218

252219
def _c_api_value_to_dtype(value: int) -> Dtype:
@@ -282,16 +249,6 @@ def _str_to_dtype(value: int) -> Dtype:
282249
# return out
283250

284251

285-
def dim4(d0: int = 1, d1: int = 1, d2: int = 1, d3: int = 1): # type: ignore # FIXME
286-
c_dim4 = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
287-
out = c_dim4(1, 1, 1, 1)
288-
289-
for i, dim in enumerate((d0, d1, d2, d3)):
290-
if dim is not None:
291-
out[i] = c_dim_t(dim)
292-
293-
return out
294-
295252
# TODO replace candidate below
296253
# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
297254
# assert(isinstance(dims, tuple))

arrayfire/array_api/_dtypes.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import ctypes
24
from dataclasses import dataclass
35
from typing import Type
@@ -31,6 +33,39 @@ class Dtype:
3133
bool = Dtype("b", ctypes.c_bool, "bool", 4)
3234

3335
supported_dtypes = [
34-
# int8,
3536
int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, bool
3637
]
38+
39+
40+
class CShape(tuple):
41+
def __new__(cls, *args: int) -> CShape:
42+
cls.original_shape = len(args)
43+
return tuple.__new__(cls, args)
44+
45+
def __init__(self, x1: int = 1, x2: int = 1, x3: int = 1, x4: int = 1) -> None:
46+
self.x1 = x1
47+
self.x2 = x2
48+
self.x3 = x3
49+
self.x4 = x4
50+
51+
def __repr__(self) -> str:
52+
return f"{self.__class__.__name__}{self.x1, self.x2, self.x3, self.x4}"
53+
54+
@property
55+
def c_array(self): # type: ignore[no-untyped-def]
56+
c_shape = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
57+
return c_shape(c_dim_t(self.x1), c_dim_t(self.x2), c_dim_t(self.x3), c_dim_t(self.x4))
58+
59+
60+
# @safe_call
61+
# def backend()
62+
# ...
63+
64+
# @backend(safe=True)
65+
# def af_get_type(arr) -> ...:
66+
# dty = ctypes.c_int()
67+
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
68+
# return dty
69+
70+
# def new_dtype():
71+
# return af_get_type(self.arr)

arrayfire/array_api/pytest.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[pytest]
2+
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --mypy --isort ./arrayfire/array_api
3+
console_output_style = classic
4+
markers = mypy

arrayfire/array_api/tests/test_array.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from arrayfire.array_api import Array, float32
24

35

@@ -9,3 +11,18 @@ def test_empty_array() -> None:
911
assert array.size == 0
1012
assert array.shape == ()
1113
assert len(array) == 0
14+
15+
16+
def test_array_from_1d_list() -> None:
17+
array = Array([1, 2, 3])
18+
19+
assert array.dtype == float32
20+
assert array.ndim == 1
21+
assert array.size == 3
22+
assert array.shape == (3,)
23+
assert len(array) == 3
24+
25+
26+
def test_array_from_2d_list() -> None:
27+
with pytest.raises(TypeError):
28+
Array([[1, 2, 3], [1, 2, 3]])

0 commit comments

Comments
 (0)