Skip to content

Commit 9daa33a

Browse files
committed
address review
1 parent 4c786a2 commit 9daa33a

File tree

4 files changed

+27
-32
lines changed

4 files changed

+27
-32
lines changed

src/array_api_extra/_lib/_backends.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,28 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
1818
----------
1919
value : str
2020
String describing the backend.
21-
library_name : str
22-
Name of the array library of the backend.
21+
is_namespace : Callable[[ModuleType], bool]
22+
Function to check whether an input module is the array namespace
23+
corresponding to the backend.
2324
module_name : str
2425
Name of the backend's module.
2526
"""
2627

27-
ARRAY_API_STRICT = "array_api_strict", "array_api_strict", "array_api_strict"
28-
NUMPY = "numpy", "numpy", "numpy"
29-
NUMPY_READONLY = "numpy_readonly", "numpy", "numpy"
30-
CUPY = "cupy", "cupy", "cupy"
31-
TORCH = "torch", "torch", "torch"
32-
DASK_ARRAY = "dask.array", "dask", "dask.array"
33-
SPARSE = "sparse", "pydata_sparse", "sparse"
34-
JAX_NUMPY = "jax.numpy", "jax", "jax.numpy"
28+
ARRAY_API_STRICT = (
29+
"array_api_strict",
30+
_compat.is_array_api_strict_namespace,
31+
"array_api_strict",
32+
)
33+
NUMPY = "numpy", _compat.is_numpy_namespace, "numpy"
34+
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy"
35+
CUPY = "cupy", _compat.is_cupy_namespace, "cupy"
36+
TORCH = "torch", _compat.is_torch_namespace, "torch"
37+
DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array"
38+
SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse"
39+
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy"
3540

3641
def __new__(
37-
cls, value: str, _library_name: str, _module_name: str
42+
cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str
3843
): # numpydoc ignore=GL08
3944
obj = object.__new__(cls)
4045
obj._value_ = value
@@ -43,30 +48,12 @@ def __new__(
4348
def __init__(
4449
self,
4550
value: str, # noqa: ARG002 # pylint: disable=unused-argument
46-
library_name: str,
51+
is_namespace: Callable[[ModuleType], bool],
4752
module_name: str,
4853
): # numpydoc ignore=GL08
49-
self.library_name = library_name
54+
self.is_namespace = is_namespace
5055
self.module_name = module_name
5156

5257
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5358
"""Pretty-print parameterized test names."""
5459
return cast(str, self.value)
55-
56-
def is_namespace(self, xp: ModuleType) -> bool:
57-
"""
58-
Call the corresponding is_namespace function.
59-
60-
Parameters
61-
----------
62-
xp : array_namespace
63-
Array namespace to check.
64-
65-
Returns
66-
-------
67-
bool
68-
``True`` if xp matches the namespace, ``False`` otherwise.
69-
"""
70-
is_namespace_func = getattr(_compat, f"is_{self.library_name}_namespace")
71-
is_namespace_func = cast(Callable[[ModuleType], bool], is_namespace_func)
72-
return is_namespace_func(xp)

src/array_api_extra/_lib/_utils/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace,
88
device,
9+
is_array_api_strict_namespace,
910
is_cupy_namespace,
11+
is_dask_namespace,
1012
is_jax_array,
1113
is_jax_namespace,
1214
is_numpy_namespace,
@@ -19,7 +21,9 @@
1921
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
2022
array_namespace,
2123
device,
24+
is_array_api_strict_namespace,
2225
is_cupy_namespace,
26+
is_dask_namespace,
2327
is_jax_array,
2428
is_jax_namespace,
2529
is_numpy_namespace,
@@ -32,7 +36,9 @@
3236
__all__ = [
3337
"array_namespace",
3438
"device",
39+
"is_array_api_strict_namespace",
3540
"is_cupy_namespace",
41+
"is_dask_namespace",
3642
"is_jax_array",
3743
"is_jax_namespace",
3844
"is_numpy_namespace",

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def array_namespace(
1818
use_compat: bool | None = None,
1919
) -> ArrayModule: ...
2020
def device(x: Array, /) -> Device: ...
21+
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
2122
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
23+
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
2224
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2325
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
2426
def is_torch_namespace(xp: ModuleType, /) -> bool: ...

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def test_ndim(self, xp: ModuleType):
421421
def test_mode_not_implemented(self, xp: ModuleType):
422422
a = xp.arange(3)
423423
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
424-
pad(a, 2, mode="edge") # type: ignore[arg-type]
424+
pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
425425

426426
def test_device(self, xp: ModuleType, device: Device):
427427
a = xp.asarray(0.0, device=device)

0 commit comments

Comments
 (0)