Skip to content

Commit 16c4dd8

Browse files
committed
Use separate wrapped dtype objects in numpy.array_api
This way there is no ambiguity about the fact the non-portability of NumPy dtype behavior, or the fact that NumPy dtypes are not necessarily allowed as dtypes for non-NumPy array APIs. Fixes #23883 Original NumPy Commit: 13ab654e46110221b6388aaad606a3625f43db5a
1 parent 1a13e76 commit 16c4dd8

9 files changed

+106
-46
lines changed

array_api_strict/_array_object.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from enum import IntEnum
2020
from ._creation_functions import asarray
2121
from ._dtypes import (
22+
_DType,
2223
_all_dtypes,
2324
_boolean_dtypes,
2425
_integer_dtypes,
@@ -81,11 +82,13 @@ def _new(cls, x, /):
8182
if isinstance(x, np.generic):
8283
# Convert the array scalar to a 0-D array
8384
x = np.asarray(x)
84-
if x.dtype not in _all_dtypes:
85+
_dtype = _DType(x.dtype)
86+
if _dtype not in _all_dtypes:
8587
raise TypeError(
8688
f"The array_api namespace does not support the dtype '{x.dtype}'"
8789
)
8890
obj._array = x
91+
obj._dtype = _dtype
8992
return obj
9093

9194
# Prevent Array() from working
@@ -107,7 +110,7 @@ def __repr__(self: Array, /) -> str:
107110
"""
108111
Performs the operation __repr__.
109112
"""
110-
suffix = f", dtype={self.dtype.name})"
113+
suffix = f", dtype={self.dtype})"
111114
if 0 in self.shape:
112115
prefix = "empty("
113116
mid = str(self.shape)
@@ -182,6 +185,7 @@ def _promote_scalar(self, scalar):
182185
integer that is too large to fit in a NumPy integer dtype, or
183186
TypeError when the scalar type is incompatible with the dtype of self.
184187
"""
188+
from ._data_type_functions import iinfo
185189
# Note: Only Python scalar types that match the array dtype are
186190
# allowed.
187191
if isinstance(scalar, bool):
@@ -195,7 +199,7 @@ def _promote_scalar(self, scalar):
195199
"Python int scalars cannot be promoted with bool arrays"
196200
)
197201
if self.dtype in _integer_dtypes:
198-
info = np.iinfo(self.dtype)
202+
info = iinfo(self.dtype)
199203
if not (info.min <= scalar <= info.max):
200204
raise OverflowError(
201205
"Python int scalars must be within the bounds of the dtype for integer arrays"
@@ -221,7 +225,7 @@ def _promote_scalar(self, scalar):
221225
# behavior for integers within the bounds of the integer dtype.
222226
# Outside of those bounds we use the default NumPy behavior (either
223227
# cast or raise OverflowError).
224-
return Array._new(np.array(scalar, self.dtype))
228+
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype))
225229

226230
@staticmethod
227231
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
@@ -331,7 +335,9 @@ def _validate_index(self, key):
331335
for i in _key:
332336
if i is not None:
333337
nonexpanding_key.append(i)
334-
if isinstance(i, Array) or isinstance(i, np.ndarray):
338+
if isinstance(i, np.ndarray):
339+
raise IndexError("Index arrays for np.array_api must be np.array_api arrays")
340+
if isinstance(i, Array):
335341
if i.dtype in _boolean_dtypes:
336342
key_has_mask = True
337343
single_axes.append(i)
@@ -1084,7 +1090,7 @@ def dtype(self) -> Dtype:
10841090
10851091
See its docstring for more information.
10861092
"""
1087-
return self._array.dtype
1093+
return self._dtype
10881094

10891095
@property
10901096
def device(self) -> Device:

array_api_strict/_creation_functions.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,15 @@
1212
SupportsBufferProtocol,
1313
)
1414
from collections.abc import Sequence
15-
from ._dtypes import _all_dtypes
15+
from ._dtypes import _DType, _all_dtypes
1616

1717
import numpy as np
1818

1919

2020
def _check_valid_dtype(dtype):
2121
# Note: Only spelling dtypes as the dtype objects is supported.
22-
23-
# We use this instead of "dtype in _all_dtypes" because the dtype objects
24-
# define equality with the sorts of things we want to disallow.
25-
for d in (None,) + _all_dtypes:
26-
if dtype is d:
27-
return
28-
raise ValueError("dtype must be one of the supported dtypes")
22+
if not dtype in (None,) + _all_dtypes:
23+
raise ValueError("dtype must be one of the supported dtypes")
2924

3025

3126
def asarray(
@@ -68,6 +63,8 @@ def asarray(
6863
# Give a better error message in this case. NumPy would convert this
6964
# to an object array. TODO: This won't handle large integers in lists.
7065
raise OverflowError("Integer out of bounds for array dtypes")
66+
if dtype is not None:
67+
dtype = dtype._np_dtype
7168
res = np.asarray(obj, dtype=dtype)
7269
return Array._new(res)
7370

@@ -91,6 +88,8 @@ def arange(
9188
_check_valid_dtype(dtype)
9289
if device not in [CPU_DEVICE, None]:
9390
raise ValueError(f"Unsupported device {device!r}")
91+
if dtype is not None:
92+
dtype = dtype._np_dtype
9493
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
9594

9695

@@ -110,6 +109,8 @@ def empty(
110109
_check_valid_dtype(dtype)
111110
if device not in [CPU_DEVICE, None]:
112111
raise ValueError(f"Unsupported device {device!r}")
112+
if dtype is not None:
113+
dtype = dtype._np_dtype
113114
return Array._new(np.empty(shape, dtype=dtype))
114115

115116

@@ -126,6 +127,8 @@ def empty_like(
126127
_check_valid_dtype(dtype)
127128
if device not in [CPU_DEVICE, None]:
128129
raise ValueError(f"Unsupported device {device!r}")
130+
if dtype is not None:
131+
dtype = dtype._np_dtype
129132
return Array._new(np.empty_like(x._array, dtype=dtype))
130133

131134

@@ -148,6 +151,8 @@ def eye(
148151
_check_valid_dtype(dtype)
149152
if device not in [CPU_DEVICE, None]:
150153
raise ValueError(f"Unsupported device {device!r}")
154+
if dtype is not None:
155+
dtype = dtype._np_dtype
151156
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
152157

153158

@@ -176,8 +181,10 @@ def full(
176181
raise ValueError(f"Unsupported device {device!r}")
177182
if isinstance(fill_value, Array) and fill_value.ndim == 0:
178183
fill_value = fill_value._array
184+
if dtype is not None:
185+
dtype = dtype._np_dtype
179186
res = np.full(shape, fill_value, dtype=dtype)
180-
if res.dtype not in _all_dtypes:
187+
if _DType(res.dtype) not in _all_dtypes:
181188
# This will happen if the fill value is not something that NumPy
182189
# coerces to one of the acceptable dtypes.
183190
raise TypeError("Invalid input to full")
@@ -202,8 +209,10 @@ def full_like(
202209
_check_valid_dtype(dtype)
203210
if device not in [CPU_DEVICE, None]:
204211
raise ValueError(f"Unsupported device {device!r}")
212+
if dtype is not None:
213+
dtype = dtype._np_dtype
205214
res = np.full_like(x._array, fill_value, dtype=dtype)
206-
if res.dtype not in _all_dtypes:
215+
if _DType(res.dtype) not in _all_dtypes:
207216
# This will happen if the fill value is not something that NumPy
208217
# coerces to one of the acceptable dtypes.
209218
raise TypeError("Invalid input to full_like")
@@ -230,6 +239,8 @@ def linspace(
230239
_check_valid_dtype(dtype)
231240
if device not in [CPU_DEVICE, None]:
232241
raise ValueError(f"Unsupported device {device!r}")
242+
if dtype is not None:
243+
dtype = dtype._np_dtype
233244
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
234245

235246

@@ -269,6 +280,8 @@ def ones(
269280
_check_valid_dtype(dtype)
270281
if device not in [CPU_DEVICE, None]:
271282
raise ValueError(f"Unsupported device {device!r}")
283+
if dtype is not None:
284+
dtype = dtype._np_dtype
272285
return Array._new(np.ones(shape, dtype=dtype))
273286

274287

@@ -285,6 +298,8 @@ def ones_like(
285298
_check_valid_dtype(dtype)
286299
if device not in [CPU_DEVICE, None]:
287300
raise ValueError(f"Unsupported device {device!r}")
301+
if dtype is not None:
302+
dtype = dtype._np_dtype
288303
return Array._new(np.ones_like(x._array, dtype=dtype))
289304

290305

@@ -332,6 +347,8 @@ def zeros(
332347
_check_valid_dtype(dtype)
333348
if device not in [CPU_DEVICE, None]:
334349
raise ValueError(f"Unsupported device {device!r}")
350+
if dtype is not None:
351+
dtype = dtype._np_dtype
335352
return Array._new(np.zeros(shape, dtype=dtype))
336353

337354

@@ -348,4 +365,6 @@ def zeros_like(
348365
_check_valid_dtype(dtype)
349366
if device not in [CPU_DEVICE, None]:
350367
raise ValueError(f"Unsupported device {device!r}")
368+
if dtype is not None:
369+
dtype = dtype._np_dtype
351370
return Array._new(np.zeros_like(x._array, dtype=dtype))

array_api_strict/_data_type_functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import (
5+
_DType,
56
_all_dtypes,
67
_boolean_dtypes,
78
_signed_integer_dtypes,
@@ -27,7 +28,7 @@
2728
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
2829
if not copy and dtype == x.dtype:
2930
return x
30-
return Array._new(x._array.astype(dtype=dtype, copy=copy))
31+
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy))
3132

3233

3334
def broadcast_arrays(*arrays: Array) -> List[Array]:
@@ -107,6 +108,8 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
107108
108109
See its docstring for more information.
109110
"""
111+
if isinstance(type, _DType):
112+
type = type._np_dtype
110113
fi = np.finfo(type)
111114
# Note: The types of the float data here are float, whereas in NumPy they
112115
# are scalars of the corresponding float dtype.
@@ -126,6 +129,8 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
126129
127130
See its docstring for more information.
128131
"""
132+
if isinstance(type, _DType):
133+
type = type._np_dtype
129134
ii = np.iinfo(type)
130135
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
131136

array_api_strict/_dtypes.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,42 @@
11
import numpy as np
22

3-
# Note: we use dtype objects instead of dtype classes. The spec does not
4-
# require any behavior on dtypes other than equality.
5-
int8 = np.dtype("int8")
6-
int16 = np.dtype("int16")
7-
int32 = np.dtype("int32")
8-
int64 = np.dtype("int64")
9-
uint8 = np.dtype("uint8")
10-
uint16 = np.dtype("uint16")
11-
uint32 = np.dtype("uint32")
12-
uint64 = np.dtype("uint64")
13-
float32 = np.dtype("float32")
14-
float64 = np.dtype("float64")
15-
complex64 = np.dtype("complex64")
16-
complex128 = np.dtype("complex128")
3+
# Note: we wrap the NumPy dtype objects in a bare class, so that none of the
4+
# additional methods and behaviors of NumPy dtype objects are exposed.
5+
6+
class _DType:
7+
def __init__(self, np_dtype):
8+
np_dtype = np.dtype(np_dtype)
9+
self._np_dtype = np_dtype
10+
11+
def __repr__(self):
12+
return f"np.array_api.{self._np_dtype.name}"
13+
14+
def __eq__(self, other):
15+
if not isinstance(other, _DType):
16+
return NotImplemented
17+
return self._np_dtype == other._np_dtype
18+
19+
def __hash__(self):
20+
# Note: this is not strictly required
21+
# (https://github.com/data-apis/array-api/issues/582), but makes the
22+
# dtype objects much easier to work with here and elsewhere if they
23+
# can be used as dict keys.
24+
return hash(self._np_dtype)
25+
26+
int8 = _DType("int8")
27+
int16 = _DType("int16")
28+
int32 = _DType("int32")
29+
int64 = _DType("int64")
30+
uint8 = _DType("uint8")
31+
uint16 = _DType("uint16")
32+
uint32 = _DType("uint32")
33+
uint64 = _DType("uint64")
34+
float32 = _DType("float32")
35+
float64 = _DType("float64")
36+
complex64 = _DType("complex64")
37+
complex128 = _DType("complex128")
1738
# Note: This name is changed
18-
bool = np.dtype("bool")
39+
bool = _DType("bool")
1940

2041
_all_dtypes = (
2142
int8,

array_api_strict/_manipulation_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def concat(
2020
# (no for scalars with axis=None, no cross-kind casting)
2121
dtype = result_type(*arrays)
2222
arrays = tuple(a._array for a in arrays)
23-
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
23+
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype))
2424

2525

2626
def expand_dims(x: Array, /, *, axis: int) -> Array:
@@ -53,8 +53,8 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
5353

5454

5555
# Note: the optional argument is called 'shape', not 'newshape'
56-
def reshape(x: Array,
57-
/,
56+
def reshape(x: Array,
57+
/,
5858
shape: Tuple[int, ...],
5959
*,
6060
copy: Optional[Bool] = None) -> Array:

array_api_strict/_statistical_functions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ def prod(
6767
# special-case it here
6868
if dtype is None:
6969
if x.dtype == float32:
70-
dtype = float64
70+
dtype = np.float64
7171
elif x.dtype == complex64:
72-
dtype = complex128
72+
dtype = np.complex128
73+
else:
74+
dtype = dtype._np_dtype
7375
return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
7476

7577

@@ -102,9 +104,11 @@ def sum(
102104
# special-case it here
103105
if dtype is None:
104106
if x.dtype == float32:
105-
dtype = float64
107+
dtype = np.float64
106108
elif x.dtype == complex64:
107-
dtype = complex128
109+
dtype = np.complex128
110+
else:
111+
dtype = dtype._np_dtype
108112
return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
109113

110114

array_api_strict/linalg.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
complex64,
99
complex128
1010
)
11+
from ._data_type_functions import finfo
1112
from ._manipulation_functions import reshape
1213
from ._array_object import Array
1314

@@ -204,7 +205,7 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A
204205
raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
205206
S = np.linalg.svd(x._array, compute_uv=False)
206207
if rtol is None:
207-
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
208+
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * finfo(S.dtype).eps
208209
else:
209210
if isinstance(rtol, Array):
210211
rtol = rtol._array
@@ -254,7 +255,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
254255
# Note: this is different from np.linalg.pinv, which does not multiply the
255256
# default tolerance by max(M, N).
256257
if rtol is None:
257-
rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
258+
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
258259
return Array._new(np.linalg.pinv(x._array, rcond=rtol))
259260

260261
def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
@@ -384,9 +385,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr
384385
# _statistical_functions.py)
385386
if dtype is None:
386387
if x.dtype == float32:
387-
dtype = float64
388+
dtype = np.float64
388389
elif x.dtype == complex64:
389-
dtype = complex128
390+
dtype = np.complex128
391+
else:
392+
dtype = dtype._np_dtype
390393
# Note: trace always operates on the last two axes, whereas np.trace
391394
# operates on the first two axes by default
392395
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)))

0 commit comments

Comments
 (0)