Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546)
* Changed Windows-specific logic in dpnp initialization [#2553](https://github.com/IntelPython/dpnp/pull/2553)
* Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571)
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)

### Deprecated

Expand Down
25 changes: 24 additions & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,14 +2508,17 @@ def pad(array, pad_width, mode="constant", **kwargs):
----------
array : {dpnp.ndarray, usm_ndarray}
The array of rank ``N`` to pad.
pad_width : {sequence, array_like, int}
pad_width : {sequence, array_like, int, dict}
Number of values padded to the edges of each axis.
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after pad for each axis.
``(pad,)`` or ``int`` is a shortcut for ``before = after = pad`` width
for all axes.
If a dictionary, each key is an axis and its corresponding value is an
integer or a pair of integers describing the padding ``(before, after)``
or ``pad`` width for that axis.
mode : {str, function}, optional
One of the following string values or a user supplied function.

Expand Down Expand Up @@ -2698,6 +2701,26 @@ def pad(array, pad_width, mode="constant", **kwargs):
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 100, 100, 100, 100, 100]])

>>> a = np.arange(1, 7).reshape(2, 3)
>>> np.pad(a, {1: (1, 2)})
array([[0, 1, 2, 3, 0, 0],
[0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {-1: 2})
array([[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {0: (3, 0)})
array([[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[1, 2, 3],
[4, 5, 6]])
>>> np.pad(a, {0: (3, 0), 1: 2})
array([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])

"""

dpnp.check_supported_arrays_type(array)
Expand Down
53 changes: 49 additions & 4 deletions dpnp/dpnp_utils/dpnp_utils_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
return left_stat, right_stat


def _pad_normalize_dict_width(pad_width, ndim):
"""
Normalize pad width passed as a dictionary.

Parameters
----------
pad_width : dict
Padding specification. The keys must be integer axis indices, and
the values must be either:
- a single int (same padding before and after),
- a tuple of two ints (before, after).
ndim : int
Number of dimensions in the input array.

Returns
-------
seq : list
A (ndim, 2) list of padding widths for each axis.

Raises
------
TypeError
If the padding format for any axis is invalid.

"""

seq = [(0, 0)] * ndim
for axis, width in pad_width.items():
if isinstance(width, int):
seq[axis] = (width, width)
elif (
isinstance(width, tuple)
and len(width) == 2
and all(isinstance(w, int) for w in width)
):
seq[axis] = width
else:
raise TypeError(f"Invalid pad width for axis {axis}: {width}")
return seq


def _pad_simple(array, pad_width, fill_value=None):
"""
Copied from numpy/lib/_arraypad_impl.py
Expand Down Expand Up @@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
def dpnp_pad(array, pad_width, mode="constant", **kwargs):
"""Pad an array."""

nd = array.ndim

if isinstance(pad_width, int):
if pad_width < 0:
raise ValueError("index can't contain negative values")
pad_width = ((pad_width, pad_width),) * array.ndim
pad_width = ((pad_width, pad_width),) * nd
else:
if dpnp.is_supported_array_type(pad_width):
pad_width = dpnp.asnumpy(pad_width)
else:
if isinstance(pad_width, dict):
pad_width = _pad_normalize_dict_width(pad_width, nd)
pad_width = numpy.asarray(pad_width)

if not pad_width.dtype.kind == "i":
raise TypeError("`pad_width` must be of integral type.")

# Broadcast to shape (array.ndim, 2)
pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
# Broadcast to shape (nd, 2)
pad_width = _as_pairs(pad_width, nd, as_index=True)

if callable(mode):
function = mode
Expand Down Expand Up @@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
if (
dpnp.isscalar(values)
and values == 0
and (array.ndim == 1 or array.size < 3e7)
and (nd == 1 or array.size < 3e7)
):
# faster path for 1d arrays or small n-dimensional arrays
return _pad_simple(array, pad_width, 0)[0]
Expand Down
18 changes: 18 additions & 0 deletions dpnp/tests/test_arraypad.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,21 @@ def test_as_pairs_exceptions(self):
dpnp_as_pairs([[1, 2], [3, 4]], 3)
with pytest.raises(ValueError, match="could not be broadcast"):
dpnp_as_pairs(dpnp.ones((2, 3)), 3)

@testing.with_requires("numpy>=2.4")
@pytest.mark.parametrize(
"sh, pad_width",
[
((3, 4, 5), {-2: (1, 3)}),
((3, 4, 5), {0: (5, 2)}),
((3, 4, 5), {0: (5, 2), -1: (3, 4)}),
((3, 4, 5), {1: 5}),
],
)
def test_dict_pad_width(self, sh, pad_width):
a = numpy.zeros(sh)
ia = dpnp.array(a)

result = dpnp.pad(ia, pad_width)
expected = numpy.pad(a, pad_width)
assert_equal(result, expected)
Loading