Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad

__version__ = "0.4.1.dev0"

Expand All @@ -14,4 +14,5 @@
"kron",
"setdiff1d",
"sinc",
"pad",
]
56 changes: 55 additions & 1 deletion src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import warnings

from ._lib import _compat, _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace, is_torch_namespace, is_array_api_strict_namespace
)
from ._lib._typing import Array, ModuleType

__all__ = [
Expand All @@ -14,6 +16,7 @@
"kron",
"setdiff1d",
"sinc",
"pad",
]


Expand Down Expand Up @@ -538,3 +541,54 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs):
"""
Pad the input array.

Parameters
----------
x : array
Input array
pad_width: int
Pad the input array with this many elements from each side
mode: str, optional
Only "constant" mode is currently supported.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values: python scalar, optional
Use this value to pad the input. Default is zero.

Returns
-------
array
The input array, padded with ``pad_width`` elements equal to ``constant_values``
"""
# xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
# http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045

if mode != 'constant':
raise NotImplementedError()

value = kwargs.get("constant_values", 0)
if kwargs and list(kwargs.keys()) != ['constant_values']:
raise ValueError(f"Unknown kwargs: {kwargs}")

if xp is None:
xp = array_namespace(x)

if is_array_api_strict_namespace(xp):
padded = xp.full(
tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype
)
padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x
return padded
elif is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=value)

else:
return xp.pad(x, pad_width, mode=mode, **kwargs)
2 changes: 2 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_torch_namespace,
is_array_api_strict_namespace,
)

__all__ = [
Expand Down
22 changes: 22 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
kron,
setdiff1d,
sinc,
pad,
)
from array_api_extra._lib._typing import Array

Expand Down Expand Up @@ -385,3 +386,24 @@ def test_device(self):

def test_xp(self):
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPad:
def test_simple(self):
a = xp.arange(1, 4)
padded = pad(a, 2)
assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0]))

def test_fill_value(self):
a = xp.arange(1, 4)
padded = pad(a, 2, constant_values=42)
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))

def test_ndim(self):
a = xp.reshape(xp.arange(2*3*4), (2, 3, 4))
padded = pad(a, 2)
assert padded.shape == (6, 7, 8)

def test_typo(self):
with pytest.raises(ValueError, match="Unknown"):
pad(xp.arange(2), pad_width=3, oops=3)
Loading