Skip to content

Commit d4d05b0

Browse files
committed
fixes
1 parent 38690bb commit d4d05b0

File tree

4 files changed

+13
-43
lines changed

4 files changed

+13
-43
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Extra array functions built on top of the array API standard."""
22

3+
from ._delegators import pad
34
from ._lib._funcs import (
45
atleast_nd,
56
cov,
67
create_diagonal,
78
expand_dims,
89
kron,
9-
pad,
1010
setdiff1d,
1111
sinc,
1212
)

src/array_api_extra/_delegators.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,18 @@ def pad(
4444
"""
4545
xp = array_namespace(x) if xp is None else xp
4646

47-
value = constant_values
47+
if mode != "constant":
48+
msg = "Only `'constant'` mode is currently supported"
49+
raise NotImplementedError(msg)
4850

4951
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
5052
if is_torch_namespace(xp):
5153
pad_width = xp.asarray(pad_width)
5254
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
5355
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
54-
return xp.nn.functional.pad(x, (pad_width,), value=value)
56+
return xp.nn.functional.pad(x, (pad_width,), value=constant_values)
5557

56-
if is_numpy_namespace(x) or is_jax_namespace(xp) or is_cupy_namespace(xp):
57-
return xp.pad(x, pad_width, mode, constant_values=value)
58+
if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp):
59+
return xp.pad(x, pad_width, mode, constant_values=constant_values)
5860

59-
return _funcs.pad(x, pad_width, mode, constant_values=value, xp=xp)
61+
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -544,46 +544,14 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
544544
def pad(
545545
x: Array,
546546
pad_width: int,
547-
mode: str = "constant",
548547
*,
549548
constant_values: bool | int | float | complex = 0,
550-
xp: ModuleType | None = None,
551-
) -> Array:
552-
"""
553-
Pad the input array.
554-
555-
Parameters
556-
----------
557-
x : array
558-
Input array.
559-
pad_width : int
560-
Pad the input array with this many elements from each side.
561-
mode : str, optional
562-
Only "constant" mode is currently supported, which pads with
563-
the value passed to `constant_values`.
564-
constant_values : python scalar, optional
565-
Use this value to pad the input. Default is zero.
566-
xp : array_namespace, optional
567-
The standard-compatible namespace for `x`. Default: infer.
568-
569-
Returns
570-
-------
571-
array
572-
The input array,
573-
padded with ``pad_width`` elements equal to ``constant_values``.
574-
"""
575-
if mode != "constant":
576-
msg = "Only `'constant'` mode is currently supported"
577-
raise NotImplementedError(msg)
578-
579-
value = constant_values
580-
581-
if xp is None:
582-
xp = array_namespace(x)
583-
549+
xp: ModuleType,
550+
) -> Array: # numpydoc ignore=PR01,RT01
551+
"""See docstring in `_delegators.py`."""
584552
padded = xp.full(
585553
tuple(x + 2 * pad_width for x in x.shape),
586-
fill_value=value,
554+
fill_value=constant_values,
587555
dtype=x.dtype,
588556
device=_compat.device(x),
589557
)

src/array_api_extra/_lib/_utils/_compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# `array-api-compat` to override the import location
44

55
try:
6-
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
6+
from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
99
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]

0 commit comments

Comments
 (0)