11"""Delegation to existing implementations for Public API Functions."""
22
3- import functools
4- from enum import Enum
53from types import ModuleType
6- from typing import final
74
8- from ._lib import _funcs
9- from ._lib ._utils ._compat import (
10- array_namespace ,
11- is_cupy_namespace ,
12- is_jax_namespace ,
13- is_numpy_namespace ,
14- is_torch_namespace ,
15- )
5+ from ._lib import Library , _funcs
6+ from ._lib ._utils ._compat import array_namespace
167from ._lib ._utils ._typing import Array
178
189__all__ = ["pad" ]
1910
2011
21- @final
22- class IsNamespace (Enum ):
23- """Enum to access is_namespace functions as the backend."""
24-
25- # TODO: when Python 3.10 is dropped, use `enum.member`
26- # https://stackoverflow.com/a/74302109
27- CUPY = functools .partial (is_cupy_namespace )
28- JAX = functools .partial (is_jax_namespace )
29- NUMPY = functools .partial (is_numpy_namespace )
30- TORCH = functools .partial (is_torch_namespace )
31-
32- def __call__ (self , xp : ModuleType ) -> bool :
33- """
34- Call the is_namespace function.
35-
36- Parameters
37- ----------
38- xp : array_namespace
39- Array namespace to check.
40-
41- Returns
42- -------
43- bool
44- ``True`` if xp matches the namespace, ``False`` otherwise.
45- """
46- return self .value (xp )
47-
48-
49- CUPY = IsNamespace .CUPY
50- JAX = IsNamespace .JAX
51- NUMPY = IsNamespace .NUMPY
52- TORCH = IsNamespace .TORCH
53-
54-
55- def _delegate (xp : ModuleType , * backends : IsNamespace ) -> bool :
12+ def _delegate (xp : ModuleType , * backends : Library ) -> bool :
5613 """
5714 Check whether `xp` is one of the `backends` to delegate to.
5815
@@ -68,7 +25,7 @@ def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool:
6825 bool
6926 ``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
7027 """
71- return any (is_namespace (xp ) for is_namespace in backends )
28+ return any (backend . is_namespace (xp ) for backend in backends )
7229
7330
7431def pad (
@@ -113,13 +70,13 @@ def pad(
11370 raise NotImplementedError (msg )
11471
11572 # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
116- if _delegate (xp , TORCH ):
73+ if _delegate (xp , Library . TORCH ):
11774 pad_width = xp .asarray (pad_width )
11875 pad_width = xp .broadcast_to (pad_width , (x .ndim , 2 ))
11976 pad_width = xp .flip (pad_width , axis = (0 ,)).flatten ()
12077 return xp .nn .functional .pad (x , tuple (pad_width ), value = constant_values ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
12178
122- if _delegate (xp , NUMPY , JAX , CUPY ):
79+ if _delegate (xp , Library . NUMPY , Library . JAX_NUMPY , Library . CUPY ):
12380 return xp .pad (x , pad_width , mode , constant_values = constant_values )
12481
12582 return _funcs .pad (x , pad_width , constant_values = constant_values , xp = xp )
0 commit comments