Skip to content

Commit 247ee6d

Browse files
committed
Reverts and tweaks
1 parent a06d51f commit 247ee6d

File tree

8 files changed

+27
-27
lines changed

8 files changed

+27
-27
lines changed

array_api_compat/common/_aliases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import inspect
88
from collections.abc import Sequence
9-
from types import NoneType
109
from typing import TYPE_CHECKING, Any, NamedTuple, cast
1110

1211
from ._helpers import _check_device, array_namespace
@@ -384,7 +383,7 @@ def clip(
384383
out: Array | None = None,
385384
) -> Array:
386385
def _isscalar(a: object) -> TypeIs[float | None]:
387-
return isinstance(a, int | float | NoneType)
386+
return isinstance(a, int | float) or a is None
388387

389388
min_shape = () if _isscalar(min) else min.shape
390389
max_shape = () if _isscalar(max) else max.shape

array_api_compat/common/_helpers.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import math
1313
import sys
1414
import warnings
15-
from collections.abc import Hashable
15+
from collections.abc import Collection, Hashable
1616
from functools import lru_cache
17-
from types import NoneType
1817
from typing import (
1918
TYPE_CHECKING,
2019
Any,
2120
Final,
2221
Literal,
22+
SupportsIndex,
2323
TypeAlias,
2424
TypeGuard,
2525
cast,
@@ -51,7 +51,7 @@
5151
| ndx.Array
5252
| sparse.SparseArray
5353
| torch.Tensor
54-
| SupportsArrayNamespace
54+
| SupportsArrayNamespace[Any]
5555
)
5656

5757
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
@@ -630,9 +630,9 @@ def your_function(x, y):
630630
raise ValueError(
631631
"The given array does not have an array-api-compat wrapper"
632632
)
633-
x = cast(SupportsArrayNamespace, x)
633+
x = cast("SupportsArrayNamespace[Any]", x)
634634
namespaces.add(x.__array_namespace__(api_version=api_version))
635-
elif isinstance(x, int | float | complex | NoneType):
635+
elif isinstance(x, int | float | complex) or x is None:
636636
continue
637637
else:
638638
# TODO: Support Python scalars?
@@ -890,12 +890,10 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
890890

891891

892892
@overload
893-
def size(x: HasShape[int]) -> int: ...
893+
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
894894
@overload
895-
def size(x: HasShape[int | None]) -> int | None: ...
896-
@overload
897-
def size(x: HasShape[float]) -> int | None: ... # Dask special case
898-
def size(x: HasShape[float | None]) -> int | None:
895+
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
896+
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
899897
"""
900898
Return the total number of elements of x.
901899
@@ -910,9 +908,9 @@ def size(x: HasShape[float | None]) -> int | None:
910908
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
911909
if None in x.shape:
912910
return None
913-
out = math.prod(cast(tuple[float, ...], x.shape))
911+
out = math.prod(cast("Collection[SupportsIndex]", x.shape))
914912
# dask.array.Array.shape can contain NaN
915-
return None if math.isnan(out) else cast(int, out)
913+
return None if math.isnan(out) else out
916914

917915

918916
@lru_cache(100)
@@ -1003,7 +1001,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10031001
# on __bool__ (dask is one such example, which however is special-cased above).
10041002

10051003
# Select a single point of the array
1006-
s = size(cast(HasShape, x))
1004+
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
10071005
if s is None:
10081006
return True
10091007
xp = array_namespace(x)

array_api_compat/common/_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,14 @@ def vector_norm(
187187
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
188188
# above to avoid matrix norm logic.
189189
shape = list(x.shape)
190-
_axis = cast(
190+
axes = cast(
191191
"tuple[int, ...]",
192192
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
193193
range(x.ndim) if axis is None else axis,
194194
x.ndim,
195195
),
196196
)
197-
for i in _axis:
197+
for i in axes:
198198
shape[i] = 1
199199
res = xp.reshape(res, tuple(shape))
200200

array_api_compat/common/_typing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
6161
def __len__(self, /) -> int: ...
6262

6363

64-
class SupportsArrayNamespace(Protocol):
65-
def __array_namespace__(self, /, *, api_version: str | None) -> Namespace: ...
64+
class SupportsArrayNamespace(Protocol[_T_co]):
65+
def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
6666

6767

6868
class HasShape(Protocol[_T_co]):
6969
@property
70-
def shape(self, /) -> tuple[_T_co, ...]: ...
70+
def shape(self, /) -> _T_co: ...
7171

7272

7373
# Return type of `__array_namespace_info__.default_dtypes`

array_api_compat/cupy/_aliases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from builtins import bool as py_bool
44

55
import cupy as cp
6+
67
from ..common import _aliases, _helpers
78
from ..common._typing import NestedSequence, SupportsBufferProtocol
89
from .._internal import get_xp
@@ -119,7 +120,7 @@ def count_nonzero(
119120

120121

121122
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
122-
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
123+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
123124
return cp.take_along_axis(x, indices, axis=axis)
124125

125126

array_api_compat/numpy/_aliases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,14 @@ def count_nonzero(
119119
) -> Array:
120120
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
121121
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
122-
result = cast(Any, np.count_nonzero(x, axis=axis, keepdims=keepdims)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
122+
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
123123
if axis is None and not keepdims:
124124
return np.asarray(result)
125125
return result
126126

127127

128128
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
129-
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
129+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
130130
return np.take_along_axis(x, indices, axis=axis)
131131

132132

array_api_compat/numpy/_info.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
more details.
88
99
"""
10-
1110
from __future__ import annotations
1211

1312
from numpy import bool_ as bool
@@ -64,7 +63,7 @@ class __array_namespace_info__:
6463
6564
"""
6665

67-
__module__ = "numpy"
66+
__module__ = 'numpy'
6867

6968
def capabilities(self):
7069
"""
@@ -183,7 +182,8 @@ def default_dtypes(
183182
"""
184183
if device not in ["cpu", None]:
185184
raise ValueError(
186-
f'Device not understood. Only "cpu" is allowed, but received: {device}'
185+
'Device not understood. Only "cpu" is allowed, but received:'
186+
f' {device}'
187187
)
188188
return {
189189
"real floating": dtype(float64),
@@ -254,7 +254,8 @@ def dtypes(
254254
"""
255255
if device not in ["cpu", None]:
256256
raise ValueError(
257-
f'Device not understood. Only "cpu" is allowed, but received: {device}'
257+
'Device not understood. Only "cpu" is allowed, but received:'
258+
f' {device}'
258259
)
259260
if kind is None:
260261
return {

array_api_compat/numpy/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Device: TypeAlias = Literal["cpu"]
88

99
if TYPE_CHECKING:
10+
1011
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
1112
DType: TypeAlias = np.dtype[
1213
np.bool_

0 commit comments

Comments
 (0)