@@ -18,23 +18,28 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
1818 ----------
1919 value : str
2020 String describing the backend.
21- library_name : str
22- Name of the array library of the backend.
21+ is_namespace : Callable[[ModuleType], bool]
22+ Function to check whether an input module is the array namespace
23+ corresponding to the backend.
2324 module_name : str
2425 Name of the backend's module.
2526 """
2627
27- ARRAY_API_STRICT = "array_api_strict" , "array_api_strict" , "array_api_strict"
28- NUMPY = "numpy" , "numpy" , "numpy"
29- NUMPY_READONLY = "numpy_readonly" , "numpy" , "numpy"
30- CUPY = "cupy" , "cupy" , "cupy"
31- TORCH = "torch" , "torch" , "torch"
32- DASK_ARRAY = "dask.array" , "dask" , "dask.array"
33- SPARSE = "sparse" , "pydata_sparse" , "sparse"
34- JAX_NUMPY = "jax.numpy" , "jax" , "jax.numpy"
28+ ARRAY_API_STRICT = (
29+ "array_api_strict" ,
30+ _compat .is_array_api_strict_namespace ,
31+ "array_api_strict" ,
32+ )
33+ NUMPY = "numpy" , _compat .is_numpy_namespace , "numpy"
34+ NUMPY_READONLY = "numpy_readonly" , _compat .is_numpy_namespace , "numpy"
35+ CUPY = "cupy" , _compat .is_cupy_namespace , "cupy"
36+ TORCH = "torch" , _compat .is_torch_namespace , "torch"
37+ DASK_ARRAY = "dask.array" , _compat .is_dask_namespace , "dask.array"
38+ SPARSE = "sparse" , _compat .is_pydata_sparse_namespace , "sparse"
39+ JAX_NUMPY = "jax.numpy" , _compat .is_jax_namespace , "jax.numpy"
3540
3641 def __new__ (
37- cls , value : str , _library_name : str , _module_name : str
42+ cls , value : str , _is_namespace : Callable [[ ModuleType ], bool ] , _module_name : str
3843 ): # numpydoc ignore=GL08
3944 obj = object .__new__ (cls )
4045 obj ._value_ = value
@@ -43,30 +48,12 @@ def __new__(
4348 def __init__ (
4449 self ,
4550 value : str , # noqa: ARG002 # pylint: disable=unused-argument
46- library_name : str ,
51+ is_namespace : Callable [[ ModuleType ], bool ] ,
4752 module_name : str ,
4853 ): # numpydoc ignore=GL08
49- self .library_name = library_name
54+ self .is_namespace = is_namespace
5055 self .module_name = module_name
5156
5257 def __str__ (self ) -> str : # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5358 """Pretty-print parameterized test names."""
5459 return cast (str , self .value )
55-
56- def is_namespace (self , xp : ModuleType ) -> bool :
57- """
58- Call the corresponding is_namespace function.
59-
60- Parameters
61- ----------
62- xp : array_namespace
63- Array namespace to check.
64-
65- Returns
66- -------
67- bool
68- ``True`` if xp matches the namespace, ``False`` otherwise.
69- """
70- is_namespace_func = getattr (_compat , f"is_{ self .library_name } _namespace" )
71- is_namespace_func = cast (Callable [[ModuleType ], bool ], is_namespace_func )
72- return is_namespace_func (xp )
0 commit comments