Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_diagonal
expand_dims
kron
pad
setdiff1d
sinc
```
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import (
from ._delegators import pad
from ._lib._funcs import (
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
pad,
setdiff1d,
sinc,
)
Expand Down
61 changes: 61 additions & 0 deletions src/array_api_extra/_delegators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Delegators to existing implementations for Public API Functions."""

from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_jax_namespace,
is_numpy_namespace,
is_torch_namespace,
)
from ._lib._utils._typing import Array, ModuleType


def pad(
x: Array,
pad_width: int,
mode: str = "constant",
*,
constant_values: bool | int | float | complex = 0,
xp: ModuleType | None = None,
) -> Array:
"""
Pad the input array.
Parameters
----------
x : array
Input array.
pad_width : int
Pad the input array with this many elements from each side.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
xp = array_namespace(x) if xp is None else xp

if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
if is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, (pad_width,), value=constant_values)

Check warning on line 56 in src/array_api_extra/_delegators.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_delegators.py#L53-L56

Added lines #L53 - L56 were not covered by tests

if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
return xp.pad(x, pad_width, mode, constant_values=constant_values)

Check warning on line 59 in src/array_api_extra/_delegators.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_delegators.py#L59

Added line #L59 was not covered by tests

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Modules housing private functions."""
"""Array-agnostic implementations for the public API."""
19 changes: 0 additions & 19 deletions src/array_api_extra/_lib/_compat.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import warnings

from ._lib import _compat, _utils
from ._lib._compat import array_namespace
from ._lib._typing import Array, ModuleType
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace
from ._utils._typing import Array, ModuleType

__all__ = [
"atleast_nd",
Expand Down Expand Up @@ -136,7 +136,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _utils.mean(m, axis=1, xp=xp)
avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand Down Expand Up @@ -449,7 +449,7 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down Expand Up @@ -544,46 +544,14 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
def pad(
x: Array,
pad_width: int,
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
) -> Array:
"""
Pad the input array.

Parameters
----------
x : array
Input array.
pad_width : int
Pad the input array with this many elements from each side.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.

Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

value = constant_values

if xp is None:
xp = array_namespace(x)

xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `_delegators.py`."""
padded = xp.full(
tuple(x + 2 * pad_width for x in x.shape),
fill_value=value,
fill_value=constant_values,
dtype=x.dtype,
device=_compat.device(x),
)
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Modules housing private utility functions."""
31 changes: 31 additions & 0 deletions src/array_api_extra/_lib/_utils/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Acquire helpers from array-api-compat."""
# Allow packages that vendor both `array-api-extra` and
# `array-api-compat` to override the import location

try:
from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device, # pyright: ignore[reportUnknownVariableType]
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
)

__all__ = [
"array_namespace",
"device",
"is_cupy_namespace",
"is_jax_namespace",
"is_numpy_namespace",
"is_torch_namespace",
]
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ... # numpydoc ignore=GL08
def device(x: Array, /) -> Device: ... # numpydoc ignore=GL08
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Utility functions used by `array_api_extra/_funcs.py`."""
"""Helper functions used by `array_api_extra/_funcs.py`."""

from . import _compat
from ._typing import Array, ModuleType
Expand Down
2 changes: 1 addition & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
setdiff1d,
sinc,
)
from array_api_extra._lib._typing import Array
from array_api_extra._lib._utils._typing import Array


class TestAtLeastND:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pytest
from numpy.testing import assert_array_equal

from array_api_extra._lib._typing import Array
from array_api_extra._lib._utils import in1d
from array_api_extra._lib._utils._helpers import in1d
from array_api_extra._lib._utils._typing import Array


# some test coverage already provided by TestSetDiff1D
Expand Down
2 changes: 1 addition & 1 deletion vendor_tests/test_vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ def test_vendor_extra():

def test_vendor_extra_uses_vendor_compat():
from ._array_api_compat_vendor import array_namespace as n1
from .array_api_extra._lib._compat import array_namespace as n2
from .array_api_extra._lib._utils._compat import array_namespace as n2

assert n1 is n2
Loading