|
22 | 22 | is_array_api_obj,
|
23 | 23 | size as xp_size,
|
24 | 24 | numpy as np_compat,
|
25 |
| - device as xp_device |
| 25 | + device as xp_device, |
| 26 | + is_numpy_namespace as is_numpy, |
| 27 | + is_cupy_namespace as is_cupy, |
| 28 | + is_torch_namespace as is_torch, |
| 29 | + is_jax_namespace as is_jax, |
| 30 | + is_array_api_strict_namespace as is_array_api_strict |
26 | 31 | )
|
27 | 32 |
|
28 | 33 | __all__ = [
|
@@ -234,26 +239,6 @@ def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
|
234 | 239 | return _asarray(x, copy=True, xp=xp)
|
235 | 240 |
|
236 | 241 |
|
237 |
| -def is_numpy(xp: ModuleType) -> bool: |
238 |
| - return xp.__name__ in ('numpy', 'scipy._lib.array_api_compat.numpy') |
239 |
| - |
240 |
| - |
241 |
| -def is_cupy(xp: ModuleType) -> bool: |
242 |
| - return xp.__name__ in ('cupy', 'scipy._lib.array_api_compat.cupy') |
243 |
| - |
244 |
| - |
245 |
| -def is_torch(xp: ModuleType) -> bool: |
246 |
| - return xp.__name__ in ('torch', 'scipy._lib.array_api_compat.torch') |
247 |
| - |
248 |
| - |
249 |
| -def is_jax(xp): |
250 |
| - return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api') |
251 |
| - |
252 |
| - |
253 |
| -def is_array_api_strict(xp): |
254 |
| - return xp.__name__ == 'array_api_strict' |
255 |
| - |
256 |
| - |
257 | 242 | def _strict_check(actual, desired, xp, *,
|
258 | 243 | check_namespace=True, check_dtype=True, check_shape=True,
|
259 | 244 | check_0d=True):
|
|
0 commit comments