|
4 | 4 | from types import ModuleType |
5 | 5 | from typing import Literal |
6 | 6 |
|
7 | | -from ._lib import Backend, _funcs |
8 | | -from ._lib._utils._compat import array_namespace |
| 7 | +from ._lib import _funcs |
| 8 | +from ._lib._utils._compat import ( |
| 9 | + array_namespace, |
| 10 | + is_cupy_namespace, |
| 11 | + is_dask_namespace, |
| 12 | + is_jax_namespace, |
| 13 | + is_numpy_namespace, |
| 14 | + is_pydata_sparse_namespace, |
| 15 | + is_torch_namespace, |
| 16 | +) |
9 | 17 | from ._lib._utils._helpers import asarrays |
10 | 18 | from ._lib._utils._typing import Array |
11 | 19 |
|
12 | 20 | __all__ = ["isclose", "pad"] |
13 | 21 |
|
14 | 22 |
|
15 | | -def _delegate(xp: ModuleType, *backends: Backend) -> bool: |
16 | | - """ |
17 | | - Check whether `xp` is one of the `backends` to delegate to. |
18 | | -
|
19 | | - Parameters |
20 | | - ---------- |
21 | | - xp : array_namespace |
22 | | - Array namespace to check. |
23 | | - *backends : IsNamespace |
24 | | - Arbitrarily many backends (from the ``IsNamespace`` enum) to check. |
25 | | -
|
26 | | - Returns |
27 | | - ------- |
28 | | - bool |
29 | | - ``True`` if `xp` matches one of the `backends`, ``False`` otherwise. |
30 | | - """ |
31 | | - return any(backend.is_namespace(xp) for backend in backends) |
32 | | - |
33 | | - |
34 | 23 | def isclose( |
35 | 24 | a: Array | complex, |
36 | 25 | b: Array | complex, |
@@ -108,10 +97,15 @@ def isclose( |
108 | 97 | """ |
109 | 98 | xp = array_namespace(a, b) if xp is None else xp |
110 | 99 |
|
111 | | - if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): |
| 100 | + if ( |
| 101 | + is_numpy_namespace(xp) |
| 102 | + or is_cupy_namespace(xp) |
| 103 | + or is_dask_namespace(xp) |
| 104 | + or is_jax_namespace(xp) |
| 105 | + ): |
112 | 106 | return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) |
113 | 107 |
|
114 | | - if _delegate(xp, Backend.TORCH): |
| 108 | + if is_torch_namespace(xp): |
115 | 109 | a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support |
116 | 110 | return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) |
117 | 111 |
|
@@ -159,14 +153,19 @@ def pad( |
159 | 153 | msg = "Only `'constant'` mode is currently supported" |
160 | 154 | raise NotImplementedError(msg) |
161 | 155 |
|
| 156 | + if ( |
| 157 | + is_numpy_namespace(xp) |
| 158 | + or is_cupy_namespace(xp) |
| 159 | + or is_jax_namespace(xp) |
| 160 | + or is_pydata_sparse_namespace(xp) |
| 161 | + ): |
| 162 | + return xp.pad(x, pad_width, mode, constant_values=constant_values) |
| 163 | + |
162 | 164 | # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 |
163 | | - if _delegate(xp, Backend.TORCH): |
| 165 | + if is_torch_namespace(xp): |
164 | 166 | pad_width = xp.asarray(pad_width) |
165 | 167 | pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) |
166 | 168 | pad_width = xp.flip(pad_width, axis=(0,)).flatten() |
167 | 169 | return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
168 | 170 |
|
169 | | - if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE): |
170 | | - return xp.pad(x, pad_width, mode, constant_values=constant_values) |
171 | | - |
172 | 171 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) |
0 commit comments