Skip to content

Commit 769c16c

Browse files
committed
Fix typing in array object. Add tests
1 parent 9c0435a commit 769c16c

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@ class Array:
3939
__array_priority__ = 30
4040

4141
def __init__(
42-
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None, dtype: None | Dtype = None,
43-
pointer_source: PointerSource = PointerSource.host, shape: None | ShapeType = None,
44-
offset: None | ctypes._SimpleCData[int] = None, strides: None | ShapeType = None) -> None:
42+
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None,
43+
dtype: None | Dtype | str = None, shape: None | ShapeType = None,
44+
pointer_source: PointerSource = PointerSource.host, offset: None | ctypes._SimpleCData[int] = None,
45+
strides: None | ShapeType = None) -> None:
4546
_no_initial_dtype = False # HACK, FIXME
4647

4748
# Initialise array object
4849
self.arr = ctypes.c_void_p(0)
4950

5051
if isinstance(dtype, str):
51-
dtype = _str_to_dtype(dtype)
52+
dtype = _str_to_dtype(dtype) # type: ignore[arg-type]
5253

5354
if dtype is None:
5455
_no_initial_dtype = True

arrayfire/array_api/tests/test_array_object.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import array as pyarray
2+
13
import pytest
24

35
from arrayfire.array_api import Array, float32, int16
46
from arrayfire.array_api._dtypes import supported_dtypes
57

68
# TODO change separated methods with setup and teardown to avoid code duplication
9+
# TODO add tests for array arguments: device, offset, strides
710

811

9-
def test_empty_array() -> None:
12+
def test_create_empty_array() -> None:
1013
array = Array()
1114

1215
assert array.dtype == float32
@@ -16,7 +19,7 @@ def test_empty_array() -> None:
1619
assert len(array) == 0
1720

1821

19-
def test_empty_array_with_nonempty_dtype() -> None:
22+
def test_create_empty_array_with_nonempty_dtype() -> None:
2023
array = Array(dtype=int16)
2124

2225
assert array.dtype == int16
@@ -26,7 +29,32 @@ def test_empty_array_with_nonempty_dtype() -> None:
2629
assert len(array) == 0
2730

2831

29-
def test_empty_array_with_nonempty_shape() -> None:
32+
def test_create_empty_array_with_str_dtype() -> None:
33+
array = Array(dtype="short int")
34+
35+
assert array.dtype == int16
36+
assert array.ndim == 0
37+
assert array.size == 0
38+
assert array.shape == ()
39+
assert len(array) == 0
40+
41+
42+
def test_create_empty_array_with_literal_dtype() -> None:
43+
array = Array(dtype="h")
44+
45+
assert array.dtype == int16
46+
assert array.ndim == 0
47+
assert array.size == 0
48+
assert array.shape == ()
49+
assert len(array) == 0
50+
51+
52+
def test_create_empty_array_with_not_matching_str_dtype() -> None:
53+
with pytest.raises(TypeError):
54+
Array(dtype="hello world")
55+
56+
57+
def test_create_empty_array_with_nonempty_shape() -> None:
3058
array = Array(shape=(2, 3))
3159

3260
assert array.dtype == float32
@@ -36,7 +64,7 @@ def test_empty_array_with_nonempty_shape() -> None:
3664
assert len(array) == 2
3765

3866

39-
def test_array_from_1d_list() -> None:
67+
def test_create_array_from_1d_list() -> None:
4068
array = Array([1, 2, 3])
4169

4270
assert array.dtype == float32
@@ -46,11 +74,22 @@ def test_array_from_1d_list() -> None:
4674
assert len(array) == 3
4775

4876

49-
def test_array_from_2d_list() -> None:
77+
def test_create_array_from_2d_list() -> None:
5078
with pytest.raises(TypeError):
5179
Array([[1, 2, 3], [1, 2, 3]])
5280

5381

82+
def test_create_array_from_pyarray() -> None:
83+
py_array = pyarray.array("f", [1, 2, 3])
84+
array = Array(py_array)
85+
86+
assert array.dtype == float32
87+
assert array.ndim == 1
88+
assert array.size == 3
89+
assert array.shape == (3,)
90+
assert len(array) == 3
91+
92+
5493
def test_array_from_list_with_unsupported_dtype() -> None:
5594
for dtype in supported_dtypes:
5695
if dtype == float32:

0 commit comments

Comments
 (0)