@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555
556556def pad (
557557 x : Array ,
558- pad_width : int ,
558+ pad_width : int | tuple | list ,
559559 mode : str = "constant" ,
560560 * ,
561561 xp : ModuleType | None = None ,
@@ -568,8 +568,12 @@ def pad(
568568 ----------
569569 x : array
570570 Input array.
571- pad_width : int
571+ pad_width : int or tuple of ints or list of pairs of ints
572572 Pad the input array with this many elements from each side.
573+ If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574+ each pair applies to the corresponding axis of ``x``.
575+ A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576+ copies of this tuple.
573577 mode : str, optional
574578 Only "constant" mode is currently supported, which pads with
575579 the value passed to `constant_values`.
@@ -590,16 +594,43 @@ def pad(
590594
591595 value = constant_values
592596
597+ # make pad_width a list of length-2 tuples of ints
598+ if isinstance (pad_width , int ):
599+ pad_width = [(pad_width , pad_width )] * x .ndim
600+
601+ if isinstance (pad_width , tuple ):
602+ pad_width = [pad_width ] * x .ndim
603+
593604 if xp is None :
594605 xp = array_namespace (x )
595606
607+ slices = []
608+ newshape = []
609+ for ax , w_tpl in enumerate (pad_width ):
610+ if len (w_tpl ) != 2 :
611+ msg = f"expect a 2-tuple (before, after), got { w_tpl } ."
612+ raise ValueError (msg )
613+
614+ sh = x .shape [ax ]
615+ if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
616+ sl = slice (None , None , None )
617+ else :
618+ start , stop = w_tpl
619+ stop = None if stop == 0 else - stop
620+
621+ sl = slice (start , stop , None )
622+ sh += w_tpl [0 ] + w_tpl [1 ]
623+
624+ newshape .append (sh )
625+ slices .append (sl )
626+
596627 padded = xp .full (
597- tuple (x + 2 * pad_width for x in x . shape ),
628+ tuple (newshape ),
598629 fill_value = value ,
599630 dtype = x .dtype ,
600631 device = _compat .device (x ),
601632 )
602- padded [( slice ( pad_width , - pad_width , None ),) * x . ndim ] = x
633+ padded [tuple ( slices ) ] = x
603634 return padded
604635
605636
0 commit comments