Skip to content

Commit 118fedd

Browse files
committed
Factor out normalize_dtype function
1 parent 35bca63 commit 118fedd

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

cubed/storage/types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from itertools import starmap
44
from operator import mul
55

6-
import numpy as np
7-
86
from cubed.types import T_DType, T_RegularChunks, T_Shape
7+
from cubed.utils import normalize_dtype
98

109

1110
class ArrayMetadata:
@@ -16,7 +15,7 @@ def __init__(
1615
chunks: T_RegularChunks,
1716
):
1817
self.shape = shape
19-
self.dtype = np.dtype(dtype)
18+
self.dtype = normalize_dtype(dtype)
2019
self.chunks = chunks
2120

2221
@property

cubed/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
def array_memory(dtype: T_DType, shape: T_Shape) -> int:
3131
"""Calculate the amount of memory in bytes that an array uses."""
32-
return np.dtype(dtype).itemsize * prod(shape)
32+
return normalize_dtype(dtype).itemsize * prod(shape)
3333

3434

3535
def chunk_memory(arr) -> int:
@@ -355,3 +355,13 @@ def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]
355355
shape = cast(Tuple[int, ...], shape)
356356
shape = tuple(int(s) for s in shape)
357357
return shape
358+
359+
360+
def normalize_dtype(dtype, device=None) -> T_DType:
361+
"""Normalize a `dtype` argument to the underlying backend array API dtype.
362+
363+
This allows dtypes to be specified as `bool`, `int`, `float`, or `complex`,
364+
or a string (if the array API implementation supports it).
365+
"""
366+
367+
return np.dtype(dtype)

0 commit comments

Comments
 (0)