Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571)
* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558)
* Improved performance of `dpnp.isclose` function by implementing a dedicated kernel for scalar `rtol` and `atol` arguments [#2540](https://github.com/IntelPython/dpnp/pull/2540)
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)

### Deprecated

Expand Down
25 changes: 24 additions & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,14 +2508,17 @@ def pad(array, pad_width, mode="constant", **kwargs):
----------
array : {dpnp.ndarray, usm_ndarray}
The array of rank ``N`` to pad.
pad_width : {sequence, array_like, int}
pad_width : {sequence, array_like, int, dict}
Number of values padded to the edges of each axis.
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after pad for each axis.
``(pad,)`` or ``int`` is a shortcut for ``before = after = pad`` width
for all axes.
If a dictionary, each key is an axis and its corresponding value is an
integer or a pair of integers describing the padding ``(before, after)``
or ``pad`` width for that axis.
mode : {str, function}, optional
One of the following string values or a user supplied function.

Expand Down Expand Up @@ -2698,6 +2701,26 @@ def pad(array, pad_width, mode="constant", **kwargs):
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 100, 100, 100, 100, 100]])

>>> a = np.arange(1, 7).reshape(2, 3)
>>> np.pad(a, {1: (1, 2)})
array([[0, 1, 2, 3, 0, 0],
[0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {-1: 2})
array([[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {0: (3, 0)})
array([[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[1, 2, 3],
[4, 5, 6]])
>>> np.pad(a, {0: (3, 0), 1: 2})
array([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])

"""

dpnp.check_supported_arrays_type(array)
Expand Down
53 changes: 49 additions & 4 deletions dpnp/dpnp_utils/dpnp_utils_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
return left_stat, right_stat


def _pad_normalize_dict_width(pad_width, ndim):
"""
Normalize pad width passed as a dictionary.

Parameters
----------
pad_width : dict
Padding specification. The keys must be integer axis indices, and
the values must be either:
- a single int (same padding before and after),
- a tuple of two ints (before, after).
ndim : int
Number of dimensions in the input array.

Returns
-------
seq : list
A (ndim, 2) list of padding widths for each axis.

Raises
------
TypeError
If the padding format for any axis is invalid.

"""

seq = [(0, 0)] * ndim
for axis, width in pad_width.items():
if isinstance(width, int):
seq[axis] = (width, width)
elif (
isinstance(width, tuple)
and len(width) == 2
and all(isinstance(w, int) for w in width)
):
seq[axis] = width
else:
raise TypeError(f"Invalid pad width for axis {axis}: {width}")
return seq


def _pad_simple(array, pad_width, fill_value=None):
"""
Copied from numpy/lib/_arraypad_impl.py
Expand Down Expand Up @@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
def dpnp_pad(array, pad_width, mode="constant", **kwargs):
"""Pad an array."""

nd = array.ndim

if isinstance(pad_width, int):
if pad_width < 0:
raise ValueError("index can't contain negative values")
pad_width = ((pad_width, pad_width),) * array.ndim
pad_width = ((pad_width, pad_width),) * nd
else:
if dpnp.is_supported_array_type(pad_width):
pad_width = dpnp.asnumpy(pad_width)
else:
if isinstance(pad_width, dict):
pad_width = _pad_normalize_dict_width(pad_width, nd)
pad_width = numpy.asarray(pad_width)

if not pad_width.dtype.kind == "i":
raise TypeError("`pad_width` must be of integral type.")

# Broadcast to shape (array.ndim, 2)
pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
# Broadcast to shape (nd, 2)
pad_width = _as_pairs(pad_width, nd, as_index=True)

if callable(mode):
function = mode
Expand Down Expand Up @@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
if (
dpnp.isscalar(values)
and values == 0
and (array.ndim == 1 or array.size < 3e7)
and (nd == 1 or array.size < 3e7)
):
# faster path for 1d arrays or small n-dimensional arrays
return _pad_simple(array, pad_width, 0)[0]
Expand Down
18 changes: 18 additions & 0 deletions dpnp/tests/test_arraypad.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,21 @@ def test_as_pairs_exceptions(self):
dpnp_as_pairs([[1, 2], [3, 4]], 3)
with pytest.raises(ValueError, match="could not be broadcast"):
dpnp_as_pairs(dpnp.ones((2, 3)), 3)

@testing.with_requires("numpy>=2.4")
@pytest.mark.parametrize(
"sh, pad_width",
[
((3, 4, 5), {-2: (1, 3)}),
((3, 4, 5), {0: (5, 2)}),
((3, 4, 5), {0: (5, 2), -1: (3, 4)}),
((3, 4, 5), {1: 5}),
],
)
def test_dict_pad_width(self, sh, pad_width):
a = numpy.zeros(sh)
ia = dpnp.array(a)

result = dpnp.pad(ia, pad_width)
expected = numpy.pad(a, pad_width)
assert_equal(result, expected)
Loading