@@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
273273 return left_stat , right_stat
274274
275275
276+ def _pad_normalize_dict_width (pad_width , ndim ):
277+ """
278+ Normalize pad width passed as a dictionary.
279+
280+ Parameters
281+ ----------
282+ pad_width : dict
283+ Padding specification. The keys must be integer axis indices, and
284+ the values must be either:
285+ - a single int (same padding before and after),
286+ - a tuple of two ints (before, after).
287+ ndim : int
288+ Number of dimensions in the input array.
289+
290+ Returns
291+ -------
292+ seq : list
293+ A (ndim, 2) list of padding widths for each axis.
294+
295+ Raises
296+ ------
297+ TypeError
298+ If the padding format for any axis is invalid.
299+
300+ """
301+
302+ seq = [(0 , 0 )] * ndim
303+ for axis , width in pad_width .items ():
304+ if isinstance (width , int ):
305+ seq [axis ] = (width , width )
306+ elif (
307+ isinstance (width , tuple )
308+ and len (width ) == 2
309+ and all (isinstance (w , int ) for w in width )
310+ ):
311+ seq [axis ] = width
312+ else :
313+ raise TypeError (f"Invalid pad width for axis { axis } : { width } " )
314+ return seq
315+
316+
276317def _pad_simple (array , pad_width , fill_value = None ):
277318 """
278319 Copied from numpy/lib/_arraypad_impl.py
@@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
616657def dpnp_pad (array , pad_width , mode = "constant" , ** kwargs ):
617658 """Pad an array."""
618659
660+ nd = array .ndim
661+
619662 if isinstance (pad_width , int ):
620663 if pad_width < 0 :
621664 raise ValueError ("index can't contain negative values" )
622- pad_width = ((pad_width , pad_width ),) * array . ndim
665+ pad_width = ((pad_width , pad_width ),) * nd
623666 else :
624667 if dpnp .is_supported_array_type (pad_width ):
625668 pad_width = dpnp .asnumpy (pad_width )
626669 else :
670+ if isinstance (pad_width , dict ):
671+ pad_width = _pad_normalize_dict_width (pad_width , nd )
627672 pad_width = numpy .asarray (pad_width )
628673
629674 if not pad_width .dtype .kind == "i" :
630675 raise TypeError ("`pad_width` must be of integral type." )
631676
632- # Broadcast to shape (array.ndim , 2)
633- pad_width = _as_pairs (pad_width , array . ndim , as_index = True )
677+ # Broadcast to shape (nd , 2)
678+ pad_width = _as_pairs (pad_width , nd , as_index = True )
634679
635680 if callable (mode ):
636681 function = mode
@@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
683728 if (
684729 dpnp .isscalar (values )
685730 and values == 0
686- and (array . ndim == 1 or array .size < 3e7 )
731+ and (nd == 1 or array .size < 3e7 )
687732 ):
688733 # faster path for 1d arrays or small n-dimensional arrays
689734 return _pad_simple (array , pad_width , 0 )[0 ]
0 commit comments