Skip to content

Commit c33ee09

Browse files
authored
Make itemsize support other Array API implementations (#804)
1 parent 4430852 commit c33ee09

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

cubed/backend_array_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,17 @@ def numpy_array_to_backend_array(arr, *, dtype=None):
5959
return namespace.asarray(arr, dtype=dtype)
6060

6161

62+
def backend_dtype_to_numpy_dtype(dtype):
63+
if isinstance(dtype, np.dtype):
64+
return dtype
65+
elif isinstance(dtype, list):
66+
return np.dtype(
67+
[(field[0], backend_dtype_to_numpy_dtype(field[1])) for field in dtype]
68+
)
69+
else:
70+
a = namespace.empty((), dtype=dtype)
71+
return np.dtype(backend_array_to_numpy_array(a).dtype)
72+
73+
6274
# jax doesn't support in-place assignment, so we use .at[].set() instead.
6375
IS_IMMUTABLE_ARRAY = "jax" in xp_name

cubed/tests/test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
extract_stack_summaries,
1515
is_cloud_storage_path,
1616
is_local_path,
17+
itemsize,
1718
join_path,
1819
map_nested,
1920
memory_repr,
@@ -32,6 +33,15 @@ def test_array_memory():
3233
assert array_memory(np.int32, (0,)) == 0
3334

3435

36+
def test_itemsize():
37+
assert itemsize(np.bool) == 1
38+
assert itemsize(np.int32) == 4
39+
assert itemsize(np.int64) == 8
40+
assert itemsize(np.dtype(np.int64)) == 8
41+
assert itemsize([("a", np.int32), ("b", np.int64)]) == 12
42+
assert itemsize(np.dtype([("a", np.int32), ("b", np.int64)])) == 12
43+
44+
3545
def test_block_id_to_offset():
3646
numblocks = (5, 3)
3747
for block_id in itertools.product(*[list(range(n)) for n in numblocks]):

cubed/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import inspect
23
import itertools
34
import numbers
45
import platform
@@ -20,6 +21,7 @@
2021
import tlz as toolz
2122
from toolz import reduce
2223

24+
from cubed.backend_array_api import backend_dtype_to_numpy_dtype
2325
from cubed.backend_array_api import namespace as nxp
2426
from cubed.types import T_Chunks, T_DType, T_RectangularChunks, T_RegularChunks, T_Shape
2527
from cubed.vendor.dask.array.core import _check_regular_chunks
@@ -369,7 +371,14 @@ def normalize_dtype(dtype, device=None) -> T_DType:
369371

370372

371373
def itemsize(dtype: T_DType) -> int:
372-
return dtype.itemsize
374+
"""Return the length of one array element in bytes."""
375+
if hasattr(dtype, "itemsize") and not inspect.isdatadescriptor(dtype.itemsize):
376+
return dtype.itemsize
377+
elif isinstance(dtype, list):
378+
return sum(itemsize(v) for _, v in dtype)
379+
else:
380+
# if dtype has no itemsize property, convert to numpy and use its itemsize
381+
return backend_dtype_to_numpy_dtype(dtype).itemsize
373382

374383

375384
def normalize_chunks(

0 commit comments

Comments
 (0)