@@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
273
273
return left_stat , right_stat
274
274
275
275
276
+ def _pad_normalize_dict_width (pad_width , ndim ):
277
+ """
278
+ Normalize pad width passed as a dictionary.
279
+
280
+ Parameters
281
+ ----------
282
+ pad_width : dict
283
+ Padding specification. The keys must be integer axis indices, and
284
+ the values must be either:
285
+ - a single int (same padding before and after),
286
+ - a tuple of two ints (before, after).
287
+ ndim : int
288
+ Number of dimensions in the input array.
289
+
290
+ Returns
291
+ -------
292
+ seq : list
293
+ A (ndim, 2) list of padding widths for each axis.
294
+
295
+ Raises
296
+ ------
297
+ TypeError
298
+ If the padding format for any axis is invalid.
299
+
300
+ """
301
+
302
+ seq = [(0 , 0 )] * ndim
303
+ for axis , width in pad_width .items ():
304
+ if isinstance (width , int ):
305
+ seq [axis ] = (width , width )
306
+ elif (
307
+ isinstance (width , tuple )
308
+ and len (width ) == 2
309
+ and all (isinstance (w , int ) for w in width )
310
+ ):
311
+ seq [axis ] = width
312
+ else :
313
+ raise TypeError (f"Invalid pad width for axis { axis } : { width } " )
314
+ return seq
315
+
316
+
276
317
def _pad_simple (array , pad_width , fill_value = None ):
277
318
"""
278
319
Copied from numpy/lib/_arraypad_impl.py
@@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
616
657
def dpnp_pad (array , pad_width , mode = "constant" , ** kwargs ):
617
658
"""Pad an array."""
618
659
660
+ nd = array .ndim
661
+
619
662
if isinstance (pad_width , int ):
620
663
if pad_width < 0 :
621
664
raise ValueError ("index can't contain negative values" )
622
- pad_width = ((pad_width , pad_width ),) * array . ndim
665
+ pad_width = ((pad_width , pad_width ),) * nd
623
666
else :
624
667
if dpnp .is_supported_array_type (pad_width ):
625
668
pad_width = dpnp .asnumpy (pad_width )
626
669
else :
670
+ if isinstance (pad_width , dict ):
671
+ pad_width = _pad_normalize_dict_width (pad_width , nd )
627
672
pad_width = numpy .asarray (pad_width )
628
673
629
674
if not pad_width .dtype .kind == "i" :
630
675
raise TypeError ("`pad_width` must be of integral type." )
631
676
632
- # Broadcast to shape (array.ndim , 2)
633
- pad_width = _as_pairs (pad_width , array . ndim , as_index = True )
677
+ # Broadcast to shape (nd , 2)
678
+ pad_width = _as_pairs (pad_width , nd , as_index = True )
634
679
635
680
if callable (mode ):
636
681
function = mode
@@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
683
728
if (
684
729
dpnp .isscalar (values )
685
730
and values == 0
686
- and (array . ndim == 1 or array .size < 3e7 )
731
+ and (nd == 1 or array .size < 3e7 )
687
732
):
688
733
# faster path for 1d arrays or small n-dimensional arrays
689
734
return _pad_simple (array , pad_width , 0 )[0 ]
0 commit comments