Skip to content

Commit fcde8aa

Browse files
Merge pull request jax-ml#24896 from hawkinsp:pad
PiperOrigin-RevId: 696568236
2 parents cea8176 + ad5a062 commit fcde8aa

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4048,15 +4048,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
40484048

40494049
def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
40504050
nd = ndim(array)
4051-
constant_values = broadcast_to(constant_values, (nd, 2))
40524051
constant_values = lax_internal._convert_element_type(
40534052
constant_values, array.dtype, dtypes.is_weakly_typed(array))
4053+
constant_values_nd = ndim(constant_values)
4054+
4055+
if constant_values_nd == 0:
4056+
widths = [(low, high, 0) for (low, high) in pad_width]
4057+
return lax.pad(array, constant_values, widths)
4058+
4059+
if constant_values_nd == 1:
4060+
if constant_values.shape[-1] == 1:
4061+
widths = [(low, high, 0) for (low, high) in pad_width]
4062+
return lax.pad(array, squeeze(constant_values), widths)
4063+
elif constant_values.shape[-1] == 2:
4064+
widths = [(low, 0, 0) for (low, _) in pad_width]
4065+
array = lax.pad(array, constant_values[0], widths)
4066+
widths = [(0, high, 0) for (_, high) in pad_width]
4067+
return lax.pad(array, constant_values[1], widths)
4068+
else:
4069+
raise ValueError("jnp.pad: constant_values has unsupported shape "
4070+
f"{constant_values.shape}. If the shape is 1D or 2D, the "
4071+
"last dimension must be of size 1 or 2.")
4072+
4073+
constant_values = broadcast_to(constant_values, (nd, 2))
40544074
for i in range(nd):
40554075
widths = [(0, 0, 0)] * nd
4056-
widths[i] = (pad_width[i][0], 0, 0)
4057-
array = lax.pad(array, constant_values[i, 0], widths)
4058-
widths[i] = (0, pad_width[i][1], 0)
4059-
array = lax.pad(array, constant_values[i, 1], widths)
4076+
if pad_width[i][0] != 0:
4077+
widths[i] = (pad_width[i][0], 0, 0)
4078+
array = lax.pad(array, constant_values[i, 0], widths)
4079+
if pad_width[i][1] != 0:
4080+
widths[i] = (0, pad_width[i][1], 0)
4081+
array = lax.pad(array, constant_values[i, 1], widths)
40604082
return array
40614083

40624084

0 commit comments

Comments
 (0)