Skip to content

Commit e0046c5

Browse files
committed
combine enums
1 parent 1e59fbd commit e0046c5

File tree

8 files changed

+85
-77
lines changed

8 files changed

+85
-77
lines changed

src/array_api_extra/_delegation.py

Lines changed: 6 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,15 @@
11
"""Delegation to existing implementations for Public API Functions."""
22

3-
import functools
4-
from enum import Enum
53
from types import ModuleType
6-
from typing import final
74

8-
from ._lib import _funcs
9-
from ._lib._utils._compat import (
10-
array_namespace,
11-
is_cupy_namespace,
12-
is_jax_namespace,
13-
is_numpy_namespace,
14-
is_torch_namespace,
15-
)
5+
from ._lib import Library, _funcs
6+
from ._lib._utils._compat import array_namespace
167
from ._lib._utils._typing import Array
178

189
__all__ = ["pad"]
1910

2011

21-
@final
22-
class IsNamespace(Enum):
23-
"""Enum to access is_namespace functions as the backend."""
24-
25-
# TODO: when Python 3.10 is dropped, use `enum.member`
26-
# https://stackoverflow.com/a/74302109
27-
CUPY = functools.partial(is_cupy_namespace)
28-
JAX = functools.partial(is_jax_namespace)
29-
NUMPY = functools.partial(is_numpy_namespace)
30-
TORCH = functools.partial(is_torch_namespace)
31-
32-
def __call__(self, xp: ModuleType) -> bool:
33-
"""
34-
Call the is_namespace function.
35-
36-
Parameters
37-
----------
38-
xp : array_namespace
39-
Array namespace to check.
40-
41-
Returns
42-
-------
43-
bool
44-
``True`` if xp matches the namespace, ``False`` otherwise.
45-
"""
46-
return self.value(xp)
47-
48-
49-
CUPY = IsNamespace.CUPY
50-
JAX = IsNamespace.JAX
51-
NUMPY = IsNamespace.NUMPY
52-
TORCH = IsNamespace.TORCH
53-
54-
55-
def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool:
12+
def _delegate(xp: ModuleType, *backends: Library) -> bool:
5613
"""
5714
Check whether `xp` is one of the `backends` to delegate to.
5815
@@ -68,7 +25,7 @@ def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool:
6825
bool
6926
``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
7027
"""
71-
return any(is_namespace(xp) for is_namespace in backends)
28+
return any(backend.is_namespace(xp) for backend in backends)
7229

7330

7431
def pad(
@@ -113,13 +70,13 @@ def pad(
11370
raise NotImplementedError(msg)
11471

11572
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
116-
if _delegate(xp, TORCH):
73+
if _delegate(xp, Library.TORCH):
11774
pad_width = xp.asarray(pad_width)
11875
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
11976
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
12077
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
12178

122-
if _delegate(xp, NUMPY, JAX, CUPY):
79+
if _delegate(xp, Library.NUMPY, Library.JAX_NUMPY, Library.CUPY):
12380
return xp.pad(x, pad_width, mode, constant_values=constant_values)
12481

12582
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
"""Internals of array-api-extra."""
2+
3+
from ._libraries import Library
4+
5+
__all__ = ["Library"]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Code specifying libraries array-api-extra interacts with."""
2+
3+
from collections.abc import Callable
4+
from enum import Enum
5+
from types import ModuleType
6+
from typing import cast
7+
8+
from ._utils import _compat
9+
10+
__all__ = ["Library"]
11+
12+
13+
class Library(Enum): # numpydoc ignore=PR01,PR02
14+
"""
15+
All array library backends explicitly tested by array-api-extra.
16+
17+
Parameters
18+
----------
19+
value : str
20+
String describing the backend.
21+
library_name : str
22+
Name of the array library of the backend.
23+
module_name : str
24+
Name of the backend's module.
25+
"""
26+
27+
ARRAY_API_STRICT = "array_api_strict", "array_api_strict", "array_api_strict"
28+
NUMPY = "numpy", "numpy", "numpy"
29+
NUMPY_READONLY = "numpy_readonly", "numpy", "numpy"
30+
CUPY = "cupy", "cupy", "cupy"
31+
TORCH = "torch", "torch", "torch"
32+
DASK_ARRAY = "dask.array", "dask", "dask.array"
33+
SPARSE = "sparse", "pydata_sparse", "sparse"
34+
JAX_NUMPY = "jax.numpy", "jax", "jax.numpy"
35+
36+
def __new__(
37+
cls, value: str, _library_name: str, _module_name: str
38+
): # numpydoc ignore=GL08
39+
obj = object.__new__(cls)
40+
obj._value_ = value
41+
return obj
42+
43+
def __init__(
44+
self, _value: str, library_name: str, module_name: str
45+
): # numpydoc ignore=GL08
46+
self.library_name = library_name
47+
self.module_name = module_name
48+
49+
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
50+
"""Pretty-print parameterized test names."""
51+
return cast(str, self.value)
52+
53+
def is_namespace(self, xp: ModuleType) -> bool:
54+
"""
55+
Call the corresponding is_namespace function.
56+
57+
Parameters
58+
----------
59+
xp : array_namespace
60+
Array namespace to check.
61+
62+
Returns
63+
-------
64+
bool
65+
``True`` if xp matches the namespace, ``False`` otherwise.
66+
"""
67+
is_namespace_func = getattr(_compat, f"is_{self.library_name}_namespace")
68+
is_namespace_func = cast(Callable[[ModuleType], bool], is_namespace_func)
69+
return is_namespace_func(xp)

tests/conftest.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,16 @@
11
"""Pytest fixtures."""
22

3-
from enum import Enum
43
from types import ModuleType
54
from typing import cast
65

76
import pytest
87

8+
from array_api_extra._lib import Library
99
from array_api_extra._lib._utils._compat import array_namespace
1010
from array_api_extra._lib._utils._compat import device as get_device
1111
from array_api_extra._lib._utils._typing import Device
1212

1313

14-
class Library(Enum):
15-
"""All array libraries explicitly tested by array-api-extra."""
16-
17-
ARRAY_API_STRICT = "array_api_strict"
18-
NUMPY = "numpy"
19-
NUMPY_READONLY = "numpy_readonly"
20-
CUPY = "cupy"
21-
TORCH = "torch"
22-
DASK_ARRAY = "dask.array"
23-
SPARSE = "sparse"
24-
JAX_NUMPY = "jax.numpy"
25-
26-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
27-
"""Pretty-print parameterized test names."""
28-
return self.value
29-
30-
3114
@pytest.fixture(params=tuple(Library))
3215
def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,RT03
3316
"""
@@ -60,8 +43,7 @@ def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
6043
-------
6144
The current array namespace.
6245
"""
63-
name = "numpy" if library == Library.NUMPY_READONLY else library.value
64-
xp = pytest.importorskip(name)
46+
xp = pytest.importorskip(library.module_name)
6547
if library == Library.JAX_NUMPY:
6648
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
6749

tests/test_at.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
)
1313

1414
from array_api_extra import at
15+
from array_api_extra._lib import Library
1516
from array_api_extra._lib._funcs import _AtOp
1617
from array_api_extra._lib._testing import xp_assert_equal
1718
from array_api_extra._lib._utils._typing import Array
1819

19-
from .conftest import Library
20-
2120

2221
@pytest.fixture
2322
def array(library: Library, xp: ModuleType) -> Array:

tests/test_funcs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
setdiff1d,
1818
sinc,
1919
)
20+
from array_api_extra._lib import Library
2021
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2122
from array_api_extra._lib._utils._compat import device as get_device
2223
from array_api_extra._lib._utils._typing import Array, Device
2324

24-
from .conftest import Library
25-
2625
# mypy: disable-error-code=no-untyped-usage
2726

2827

tests/test_testing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import numpy as np
22
import pytest
33

4+
from array_api_extra._lib import Library
45
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
56

6-
from .conftest import Library
7-
87
# mypy: disable-error-code=no-any-decorated
98
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
109

tests/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import numpy as np
44
import pytest
55

6+
from array_api_extra._lib import Library
67
from array_api_extra._lib._testing import xp_assert_equal
78
from array_api_extra._lib._utils._compat import device as get_device
89
from array_api_extra._lib._utils._helpers import in1d
910
from array_api_extra._lib._utils._typing import Array, Device
1011

11-
from .conftest import Library
12-
1312
# mypy: disable-error-code=no-untyped-usage
1413

1514

0 commit comments

Comments
 (0)