Skip to content

Commit 2bd8205

Browse files
committed
green
1 parent e0046c5 commit 2bd8205

File tree

10 files changed

+85
-115
lines changed

10 files changed

+85
-115
lines changed

pixi.lock

Lines changed: 39 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
6464
[tool.pixi.dependencies]
6565
python = ">=3.10,<3.14"
6666
array-api-compat = ">=1.10.0,<2"
67+
aenum = ">=3.1.15,<4"
6768

6869
[tool.pixi.pypi-dependencies]
6970
array-api-extra = { path = ".", editable = true }

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
from types import ModuleType
44

5-
from ._lib import Library, _funcs
5+
from ._lib import Backend, _funcs
66
from ._lib._utils._compat import array_namespace
77
from ._lib._utils._typing import Array
88

99
__all__ = ["pad"]
1010

1111

12-
def _delegate(xp: ModuleType, *backends: Library) -> bool:
12+
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
1313
"""
1414
Check whether `xp` is one of the `backends` to delegate to.
1515
@@ -70,13 +70,13 @@ def pad(
7070
raise NotImplementedError(msg)
7171

7272
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
73-
if _delegate(xp, Library.TORCH):
73+
if _delegate(xp, Backend.TORCH):
7474
pad_width = xp.asarray(pad_width)
7575
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
7676
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
7777
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
7878

79-
if _delegate(xp, Library.NUMPY, Library.JAX_NUMPY, Library.CUPY):
79+
if _delegate(xp, Backend.NUMPY, Backend.JAX_NUMPY, Backend.CUPY):
8080
return xp.pad(x, pad_width, mode, constant_values=constant_values)
8181

8282
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Internals of array-api-extra."""
22

3-
from ._libraries import Library
3+
from ._backends import Backend
44

5-
__all__ = ["Library"]
5+
__all__ = ["Backend"]

src/array_api_extra/_lib/_libraries.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)