Skip to content

Commit 083c63e

Browse files
authored
Misc array API improvements (#798)
* Workaround due to shadowing builtin bool * finfo and iinfo can take an array as well as a dtype * Make default unsigned int array API compliant * Remove assumption that arrays have `nbytes` property
1 parent afbae75 commit 083c63e

File tree

3 files changed

+24
-12
lines changed

3 files changed

+24
-12
lines changed

cubed/array_api/data_type_functions.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@ def can_cast(from_, to, /):
1919

2020

2121
def finfo(type, /):
22-
return nxp.finfo(type)
22+
return nxp.finfo(_as_dtype(type))
2323

2424

2525
def iinfo(type, /):
26-
return nxp.iinfo(type)
26+
return nxp.iinfo(_as_dtype(type))
2727

2828

2929
def isdtype(dtype, kind):
3030
return nxp.isdtype(dtype, kind)
3131

3232

3333
def result_type(*arrays_and_dtypes):
34-
return nxp.result_type(
35-
*(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes)
36-
)
34+
return nxp.result_type(*(_as_dtype(a) for a in arrays_and_dtypes))
35+
36+
37+
def _as_dtype(type):
38+
return type.dtype if isinstance(type, CoreArray) else type

cubed/array_api/dtypes.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copied from numpy.array_api
2+
import builtins
3+
24
from cubed.array_api.inspection import __array_namespace_info__
35
from cubed.backend_array_api import namespace as nxp
46

@@ -111,9 +113,16 @@ def _upcast_integral_dtypes(
111113
elif x.dtype in _signed_integer_dtypes:
112114
dtype = dtypes["integral"]
113115
elif x.dtype in _unsigned_integer_dtypes:
114-
# Type arithmetic to produce an unsigned integer dtype at the same default precision.
115-
default_bits = nxp.iinfo(dtypes["integral"]).bits
116-
dtype = nxp.dtype(f"u{default_bits // 8}")
116+
# produce an unsigned integer dtype at the same default precision
117+
dtype = dtypes["integral"]
118+
if dtype == int8:
119+
dtype = uint8
120+
elif dtype == int16:
121+
dtype = uint16
122+
elif dtype == int32:
123+
dtype = uint32
124+
elif dtype == int64:
125+
dtype = uint64
117126
else:
118127
dtype = x.dtype
119128

@@ -122,8 +131,8 @@ def _upcast_integral_dtypes(
122131

123132
def _promote_scalars(x1, x2, op):
124133
"""Promote at most one of x1 or x2 to an array from a Python scalar"""
125-
x1_is_scalar = isinstance(x1, (int, float, complex, bool))
126-
x2_is_scalar = isinstance(x2, (int, float, complex, bool))
134+
x1_is_scalar = isinstance(x1, (int, float, complex, builtins.bool))
135+
x2_is_scalar = isinstance(x2, (int, float, complex, builtins.bool))
127136
if x1_is_scalar and x2_is_scalar:
128137
raise TypeError(f"At least one of x1 and x2 must be an array in {op}")
129138
elif x1_is_scalar:

cubed/storage/virtual.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,10 @@ def __init__(
8686
chunks: T_RegularChunks,
8787
max_nbytes: int = 10**6,
8888
):
89-
if array.nbytes > max_nbytes:
89+
nbytes = array_memory(array.dtype, array.shape)
90+
if nbytes > max_nbytes:
9091
raise ValueError(
91-
f"Size of in memory array is {memory_repr(array.nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
92+
f"Size of in memory array is {memory_repr(nbytes)} which exceeds maximum of {memory_repr(max_nbytes)}. Consider loading the array from storage using `from_array`."
9293
)
9394
self.array = array
9495
super().__init__(array.shape, array.dtype, chunks)

0 commit comments

Comments
 (0)