Skip to content

Commit d758d6f

Browse files
jorenhamlucascolley
andcommitted
STY: apply review suggestions
Co-authored-by: lucascolley <[email protected]>
1 parent 4f6ef6d commit d758d6f

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

array_api_compat/common/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
652652
if isdtype(x.dtype, "complex floating", xp=xp):
653653
out = (x / xp.abs(x, **kwargs))[...]
654654
# sign(0) = 0 but the above formula would give nan
655-
out[x == 0 + 0j] = 0 + 0j
655+
out[x == 0j] = 0j
656656
else:
657657
out = xp.sign(x, **kwargs)
658658
# CuPy sign() does not propagate nans. See

array_api_compat/common/_helpers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import (
1717
TYPE_CHECKING,
1818
Any,
19+
Final,
1920
Literal,
2021
SupportsIndex,
2122
TypeAlias,
@@ -56,6 +57,9 @@
5657
| _CupyArray
5758
)
5859

60+
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
61+
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
62+
5963

6064
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
6165
"""Return True if `x` is a zero-gradient array.
@@ -477,16 +481,11 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool:
477481

478482

479483
def _check_api_version(api_version: str | None) -> None:
480-
if api_version in ["2021.12", "2022.12", "2023.12"]:
484+
if api_version in _API_VERSIONS_OLD:
481485
warnings.warn(
482486
f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
483487
)
484-
elif api_version is not None and api_version not in [
485-
"2021.12",
486-
"2022.12",
487-
"2023.12",
488-
"2024.12",
489-
]:
488+
elif api_version is not None and api_version not in _API_VERSIONS:
490489
raise ValueError(
491490
"Only the 2024.12 version of the array API specification is currently supported"
492491
)

array_api_compat/dask/array/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import dask.array as da
66

7-
# These functions are in both the main and linalg namespaces
7+
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
88
from dask.array import matmul, outer, tensordot
99

1010
# Exports

array_api_compat/numpy/_aliases.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
if TYPE_CHECKING:
1616
from typing_extensions import Buffer, TypeIs
1717

18+
# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
19+
# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
1820
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
1921

2022
bool = np.bool_
@@ -130,7 +132,9 @@ def count_nonzero(
130132
axis: int | tuple[int, ...] | None = None,
131133
keepdims: py_bool = False,
132134
) -> Array:
133-
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore
135+
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
136+
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
137+
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
134138
if axis is None and not keepdims:
135139
return np.asarray(result)
136140
return result

0 commit comments

Comments
 (0)