@@ -4048,15 +4048,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
40484048
40494049def _pad_constant (array : Array , pad_width : PadValue [int ], constant_values : Array ) -> Array :
40504050 nd = ndim (array )
4051- constant_values = broadcast_to (constant_values , (nd , 2 ))
40524051 constant_values = lax_internal ._convert_element_type (
40534052 constant_values , array .dtype , dtypes .is_weakly_typed (array ))
4053+ constant_values_nd = ndim (constant_values )
4054+
4055+ if constant_values_nd == 0 :
4056+ widths = [(low , high , 0 ) for (low , high ) in pad_width ]
4057+ return lax .pad (array , constant_values , widths )
4058+
4059+ if constant_values_nd == 1 :
4060+ if constant_values .shape [- 1 ] == 1 :
4061+ widths = [(low , high , 0 ) for (low , high ) in pad_width ]
4062+ return lax .pad (array , squeeze (constant_values ), widths )
4063+ elif constant_values .shape [- 1 ] == 2 :
4064+ widths = [(low , 0 , 0 ) for (low , _ ) in pad_width ]
4065+ array = lax .pad (array , constant_values [0 ], widths )
4066+ widths = [(0 , high , 0 ) for (_ , high ) in pad_width ]
4067+ return lax .pad (array , constant_values [1 ], widths )
4068+ else :
4069+ raise ValueError ("jnp.pad: constant_values has unsupported shape "
4070+ f"{ constant_values .shape } . If the shape is 1D or 2D, the "
4071+ "last dimension must be of size 1 or 2." )
4072+
4073+ constant_values = broadcast_to (constant_values , (nd , 2 ))
40544074 for i in range (nd ):
40554075 widths = [(0 , 0 , 0 )] * nd
4056- widths [i ] = (pad_width [i ][0 ], 0 , 0 )
4057- array = lax .pad (array , constant_values [i , 0 ], widths )
4058- widths [i ] = (0 , pad_width [i ][1 ], 0 )
4059- array = lax .pad (array , constant_values [i , 1 ], widths )
4076+ if pad_width [i ][0 ] != 0 :
4077+ widths [i ] = (pad_width [i ][0 ], 0 , 0 )
4078+ array = lax .pad (array , constant_values [i , 0 ], widths )
4079+ if pad_width [i ][1 ] != 0 :
4080+ widths [i ] = (0 , pad_width [i ][1 ], 0 )
4081+ array = lax .pad (array , constant_values [i , 1 ], widths )
40604082 return array
40614083
40624084
0 commit comments