@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555
556556def pad (
557557 x : Array ,
558- pad_width : int ,
558+ pad_width : int | tuple | list ,
559559 mode : str = "constant" ,
560560 * ,
561561 xp : ModuleType | None = None ,
@@ -568,8 +568,12 @@ def pad(
568568 ----------
569569 x : array
570570 Input array.
571- pad_width : int
571+ pad_width : int or tuple of ints or list of pairs of ints
572572 Pad the input array with this many elements from each side.
573+ If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574+ each pair applies to the corresponding axis of ``x``.
575+ A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576+ copies of this tuple.
573577 mode : str, optional
574578 Only "constant" mode is currently supported, which pads with
575579 the value passed to `constant_values`.
@@ -590,16 +594,45 @@ def pad(
590594
591595 value = constant_values
592596
597+ # make pad_width a list of length-2 tuples of ints
598+ if isinstance (pad_width , int ):
599+ pad_width = [(pad_width , pad_width )] * x .ndim
600+
601+ if isinstance (pad_width , tuple ):
602+ pad_width = [pad_width ] * x .ndim
603+
593604 if xp is None :
594605 xp = array_namespace (x )
595606
607+ slices = []
608+ newshape = []
609+ for ax , w_tpl in enumerate (pad_width ):
610+ if len (w_tpl ) != 2 :
611+ raise ValueError (f"expect a 2-tuple (before, after), got { w_tpl } ." )
612+
613+ sh = x .shape [ax ]
614+ if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
615+ sl = slice (None , None , None )
616+ else :
617+ start , stop = w_tpl
618+ if stop == 0 :
619+ stop = None
620+ else :
621+ stop = - stop
622+
623+ sl = slice (start , stop , None )
624+ sh += w_tpl [0 ] + w_tpl [1 ]
625+
626+ newshape .append (sh )
627+ slices .append (sl )
628+
596629 padded = xp .full (
597- tuple (x + 2 * pad_width for x in x . shape ),
630+ tuple (newshape ),
598631 fill_value = value ,
599632 dtype = x .dtype ,
600633 device = _compat .device (x ),
601634 )
602- padded [( slice ( pad_width , - pad_width , None ),) * x . ndim ] = x
635+ padded [tuple ( slices ) ] = x
603636 return padded
604637
605638
0 commit comments