@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555
556556def pad (
557557 x : Array ,
558- pad_width : int ,
558+ pad_width : int | tuple ,
559559 mode : str = "constant" ,
560560 * ,
561561 xp : ModuleType | None = None ,
@@ -568,8 +568,9 @@ def pad(
568568 ----------
569569 x : array
570570 Input array.
571- pad_width : int
571+ pad_width : int or tuple of ints
572572 Pad the input array with this many elements from each side.
573+ Ifa tuple, each element applies to the corresponding axis of `x`.
573574 mode : str, optional
574575 Only "constant" mode is currently supported, which pads with
575576 the value passed to `constant_values`.
@@ -590,16 +591,23 @@ def pad(
590591
591592 value = constant_values
592593
594+ if isinstance (pad_width , int ):
595+ pad_width = (pad_width ,) * x .ndim
596+
593597 if xp is None :
594598 xp = array_namespace (x )
595599
596600 padded = xp .full (
597- tuple (x + 2 * pad_width for x in x .shape ),
601+ tuple (x + 2 * w for ( x , w ) in zip ( x .shape , pad_width ) ),
598602 fill_value = value ,
599603 dtype = x .dtype ,
600604 device = _compat .device (x ),
601605 )
602- padded [(slice (pad_width , - pad_width , None ),) * x .ndim ] = x
606+ sl = tuple (
607+ slice (w , - w , None ) if w > 0 else slice (None , None , None )
608+ for w in pad_width
609+ )
610+ padded [sl ] = x
603611 return padded
604612
605613
0 commit comments