Skip to content

Commit 1daf4c9

Browse files
authored
Extend dpnp.pad to support pad_width as a dictionary (#2535)
The PR adds support of `pad_width` keyword as a dictionary to align with the behavior in `numpy.pad` function.
1 parent 9a9f5d0 commit 1daf4c9

File tree

4 files changed

+92
-5
lines changed

4 files changed

+92
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3939
* Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571)
4040
* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558)
4141
* Improved performance of `dpnp.isclose` function by implementing a dedicated kernel for scalar `rtol` and `atol` arguments [#2540](https://github.com/IntelPython/dpnp/pull/2540)
42+
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)
4243

4344
### Deprecated
4445

dpnp/dpnp_iface_manipulation.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2508,14 +2508,17 @@ def pad(array, pad_width, mode="constant", **kwargs):
25082508
----------
25092509
array : {dpnp.ndarray, usm_ndarray}
25102510
The array of rank ``N`` to pad.
2511-
pad_width : {sequence, array_like, int}
2511+
pad_width : {sequence, array_like, int, dict}
25122512
Number of values padded to the edges of each axis.
25132513
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
25142514
for each axis.
25152515
``(before, after)`` or ``((before, after),)`` yields same before
25162516
and after pad for each axis.
25172517
``(pad,)`` or ``int`` is a shortcut for ``before = after = pad`` width
25182518
for all axes.
2519+
If a dictionary, each key is an axis and its corresponding value is an
2520+
integer or a pair of integers describing the padding ``(before, after)``
2521+
or ``pad`` width for that axis.
25192522
mode : {str, function}, optional
25202523
One of the following string values or a user supplied function.
25212524
@@ -2698,6 +2701,26 @@ def pad(array, pad_width, mode="constant", **kwargs):
26982701
[100, 100, 100, 100, 100, 100, 100],
26992702
[100, 100, 100, 100, 100, 100, 100]])
27002703
2704+
>>> a = np.arange(1, 7).reshape(2, 3)
2705+
>>> np.pad(a, {1: (1, 2)})
2706+
array([[0, 1, 2, 3, 0, 0],
2707+
[0, 4, 5, 6, 0, 0]])
2708+
>>> np.pad(a, {-1: 2})
2709+
array([[0, 0, 1, 2, 3, 0, 0],
2710+
[0, 0, 4, 5, 6, 0, 0]])
2711+
>>> np.pad(a, {0: (3, 0)})
2712+
array([[0, 0, 0],
2713+
[0, 0, 0],
2714+
[0, 0, 0],
2715+
[1, 2, 3],
2716+
[4, 5, 6]])
2717+
>>> np.pad(a, {0: (3, 0), 1: 2})
2718+
array([[0, 0, 0, 0, 0, 0, 0],
2719+
[0, 0, 0, 0, 0, 0, 0],
2720+
[0, 0, 0, 0, 0, 0, 0],
2721+
[0, 0, 1, 2, 3, 0, 0],
2722+
[0, 0, 4, 5, 6, 0, 0]])
2723+
27012724
"""
27022725

27032726
dpnp.check_supported_arrays_type(array)

dpnp/dpnp_utils/dpnp_utils_pad.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
273273
return left_stat, right_stat
274274

275275

276+
def _pad_normalize_dict_width(pad_width, ndim):
277+
"""
278+
Normalize pad width passed as a dictionary.
279+
280+
Parameters
281+
----------
282+
pad_width : dict
283+
Padding specification. The keys must be integer axis indices, and
284+
the values must be either:
285+
- a single int (same padding before and after),
286+
- a tuple of two ints (before, after).
287+
ndim : int
288+
Number of dimensions in the input array.
289+
290+
Returns
291+
-------
292+
seq : list
293+
A (ndim, 2) list of padding widths for each axis.
294+
295+
Raises
296+
------
297+
TypeError
298+
If the padding format for any axis is invalid.
299+
300+
"""
301+
302+
seq = [(0, 0)] * ndim
303+
for axis, width in pad_width.items():
304+
if isinstance(width, int):
305+
seq[axis] = (width, width)
306+
elif (
307+
isinstance(width, tuple)
308+
and len(width) == 2
309+
and all(isinstance(w, int) for w in width)
310+
):
311+
seq[axis] = width
312+
else:
313+
raise TypeError(f"Invalid pad width for axis {axis}: {width}")
314+
return seq
315+
316+
276317
def _pad_simple(array, pad_width, fill_value=None):
277318
"""
278319
Copied from numpy/lib/_arraypad_impl.py
@@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
616657
def dpnp_pad(array, pad_width, mode="constant", **kwargs):
617658
"""Pad an array."""
618659

660+
nd = array.ndim
661+
619662
if isinstance(pad_width, int):
620663
if pad_width < 0:
621664
raise ValueError("index can't contain negative values")
622-
pad_width = ((pad_width, pad_width),) * array.ndim
665+
pad_width = ((pad_width, pad_width),) * nd
623666
else:
624667
if dpnp.is_supported_array_type(pad_width):
625668
pad_width = dpnp.asnumpy(pad_width)
626669
else:
670+
if isinstance(pad_width, dict):
671+
pad_width = _pad_normalize_dict_width(pad_width, nd)
627672
pad_width = numpy.asarray(pad_width)
628673

629674
if not pad_width.dtype.kind == "i":
630675
raise TypeError("`pad_width` must be of integral type.")
631676

632-
# Broadcast to shape (array.ndim, 2)
633-
pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
677+
# Broadcast to shape (nd, 2)
678+
pad_width = _as_pairs(pad_width, nd, as_index=True)
634679

635680
if callable(mode):
636681
function = mode
@@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
683728
if (
684729
dpnp.isscalar(values)
685730
and values == 0
686-
and (array.ndim == 1 or array.size < 3e7)
731+
and (nd == 1 or array.size < 3e7)
687732
):
688733
# faster path for 1d arrays or small n-dimensional arrays
689734
return _pad_simple(array, pad_width, 0)[0]

dpnp/tests/test_arraypad.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,21 @@ def test_as_pairs_exceptions(self):
539539
dpnp_as_pairs([[1, 2], [3, 4]], 3)
540540
with pytest.raises(ValueError, match="could not be broadcast"):
541541
dpnp_as_pairs(dpnp.ones((2, 3)), 3)
542+
543+
@testing.with_requires("numpy>=2.4")
544+
@pytest.mark.parametrize(
545+
"sh, pad_width",
546+
[
547+
((3, 4, 5), {-2: (1, 3)}),
548+
((3, 4, 5), {0: (5, 2)}),
549+
((3, 4, 5), {0: (5, 2), -1: (3, 4)}),
550+
((3, 4, 5), {1: 5}),
551+
],
552+
)
553+
def test_dict_pad_width(self, sh, pad_width):
554+
a = numpy.zeros(sh)
555+
ia = dpnp.array(a)
556+
557+
result = dpnp.pad(ia, pad_width)
558+
expected = numpy.pad(a, pad_width)
559+
assert_equal(result, expected)

0 commit comments

Comments
 (0)