Skip to content

Commit 4a5992c

Browse files
committed
Support pad_width as a dictionary
1 parent 876e940 commit 4a5992c

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

dpnp/dpnp_utils/dpnp_utils_pad.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
273273
return left_stat, right_stat
274274

275275

276+
def _pad_normalize_dict_width(pad_width, ndim):
277+
"""
278+
Normalize pad width passed as a dictionary.
279+
280+
Parameters
281+
----------
282+
pad_width : dict
283+
Padding specification. The keys must be integer axis indices, and
284+
the values must be either:
285+
- a single int (same padding before and after),
286+
- a tuple of two ints (before, after).
287+
ndim : int
288+
Number of dimensions in the input array.
289+
290+
Returns
291+
-------
292+
seq : list
293+
A (ndim, 2) list of padding widths for each axis.
294+
295+
Raises
296+
------
297+
TypeError
298+
If the padding format for any axis is invalid.
299+
300+
"""
301+
302+
seq = [(0, 0)] * ndim
303+
for axis, width in pad_width.items():
304+
if isinstance(width, int):
305+
seq[axis] = (width, width)
306+
elif (
307+
isinstance(width, tuple)
308+
and len(width) == 2
309+
and all(isinstance(w, int) for w in width)
310+
):
311+
seq[axis] = width
312+
else:
313+
raise TypeError(f"Invalid pad width for axis {axis}: {width}")
314+
return seq
315+
316+
276317
def _pad_simple(array, pad_width, fill_value=None):
277318
"""
278319
Copied from numpy/lib/_arraypad_impl.py
@@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
616657
def dpnp_pad(array, pad_width, mode="constant", **kwargs):
617658
"""Pad an array."""
618659

660+
nd = array.ndim
661+
619662
if isinstance(pad_width, int):
620663
if pad_width < 0:
621664
raise ValueError("index can't contain negative values")
622-
pad_width = ((pad_width, pad_width),) * array.ndim
665+
pad_width = ((pad_width, pad_width),) * nd
623666
else:
624667
if dpnp.is_supported_array_type(pad_width):
625668
pad_width = dpnp.asnumpy(pad_width)
626669
else:
670+
if isinstance(pad_width, dict):
671+
pad_width = _pad_normalize_dict_width(pad_width, nd)
627672
pad_width = numpy.asarray(pad_width)
628673

629674
if not pad_width.dtype.kind == "i":
630675
raise TypeError("`pad_width` must be of integral type.")
631676

632-
# Broadcast to shape (array.ndim, 2)
633-
pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
677+
# Broadcast to shape (nd, 2)
678+
pad_width = _as_pairs(pad_width, nd, as_index=True)
634679

635680
if callable(mode):
636681
function = mode
@@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
683728
if (
684729
dpnp.isscalar(values)
685730
and values == 0
686-
and (array.ndim == 1 or array.size < 3e7)
731+
and (nd == 1 or array.size < 3e7)
687732
):
688733
# faster path for 1d arrays or small n-dimensional arrays
689734
return _pad_simple(array, pad_width, 0)[0]

0 commit comments

Comments
 (0)