Skip to content

Commit 1c9b1aa

Browse files
authored
Merge pull request scipy#21892 from ev-br/is_namespace_helpers
MAINT: `_lib`: use `is_numpy` etc helpers from the compat library
2 parents b7bd4f9 + bf3cc68 commit 1c9b1aa

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

scipy/_lib/_array_api.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
is_array_api_obj,
2323
size as xp_size,
2424
numpy as np_compat,
25-
device as xp_device
25+
device as xp_device,
26+
is_numpy_namespace as is_numpy,
27+
is_cupy_namespace as is_cupy,
28+
is_torch_namespace as is_torch,
29+
is_jax_namespace as is_jax,
30+
is_array_api_strict_namespace as is_array_api_strict
2631
)
2732

2833
__all__ = [
@@ -234,26 +239,6 @@ def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
234239
return _asarray(x, copy=True, xp=xp)
235240

236241

237-
def is_numpy(xp: ModuleType) -> bool:
238-
return xp.__name__ in ('numpy', 'scipy._lib.array_api_compat.numpy')
239-
240-
241-
def is_cupy(xp: ModuleType) -> bool:
242-
return xp.__name__ in ('cupy', 'scipy._lib.array_api_compat.cupy')
243-
244-
245-
def is_torch(xp: ModuleType) -> bool:
246-
return xp.__name__ in ('torch', 'scipy._lib.array_api_compat.torch')
247-
248-
249-
def is_jax(xp):
250-
return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api')
251-
252-
253-
def is_array_api_strict(xp):
254-
return xp.__name__ == 'array_api_strict'
255-
256-
257242
def _strict_check(actual, desired, xp, *,
258243
check_namespace=True, check_dtype=True, check_shape=True,
259244
check_0d=True):

0 commit comments

Comments
 (0)