Skip to content

Commit 9d1a6bc

Browse files
committed
simplify Backend
1 parent 1dcb9e6 commit 9d1a6bc

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

src/array_api_extra/_lib/_backends.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,23 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
1717
Parameters
1818
----------
1919
value : str
20-
String describing the backend.
20+
Name of the backend's module.
2121
is_namespace : Callable[[ModuleType], bool]
2222
Function to check whether an input module is the array namespace
2323
corresponding to the backend.
24-
module_name : str
25-
Name of the backend's module.
2624
"""
2725

28-
ARRAY_API_STRICT = (
29-
"array_api_strict",
30-
_compat.is_array_api_strict_namespace,
31-
"array_api_strict",
32-
)
33-
NUMPY = "numpy", _compat.is_numpy_namespace, "numpy"
34-
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy"
35-
CUPY = "cupy", _compat.is_cupy_namespace, "cupy"
36-
TORCH = "torch", _compat.is_torch_namespace, "torch"
37-
DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array"
38-
SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse"
39-
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy"
26+
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27+
NUMPY = "numpy", _compat.is_numpy_namespace
28+
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
29+
CUPY = "cupy", _compat.is_cupy_namespace
30+
TORCH = "torch", _compat.is_torch_namespace
31+
DASK_ARRAY = "dask.array", _compat.is_dask_namespace
32+
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
33+
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace
4034

4135
def __new__(
42-
cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str
36+
cls, value: str, _is_namespace: Callable[[ModuleType], bool]
4337
): # numpydoc ignore=GL08
4438
obj = object.__new__(cls)
4539
obj._value_ = value
@@ -49,10 +43,8 @@ def __init__(
4943
self,
5044
value: str, # noqa: ARG002 # pylint: disable=unused-argument
5145
is_namespace: Callable[[ModuleType], bool],
52-
module_name: str,
5346
): # numpydoc ignore=GL08
5447
self.is_namespace = is_namespace
55-
self.module_name = module_name
5648

5749
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5850
"""Pretty-print parameterized test names."""

tests/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Pytest fixtures."""
22

3-
from __future__ import annotations
4-
53
from collections.abc import Callable
64
from functools import wraps
75
from types import ModuleType

0 commit comments

Comments
 (0)