@@ -1306,6 +1306,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike,
13061306 Returns:
13071307 The ``operand`` array with padding value ``padding_value`` inserted in each
13081308 dimension according to the ``padding_config``.
1309+
1310+ Examples:
1311+ >>> from jax import lax
1312+ >>> import jax.numpy as jnp
1313+
1314+ Pad a 1-dimensional array with zeros, We'll specify two zeros in front and
1315+ three at the end:
1316+
1317+ >>> x = jnp.array([1, 2, 3, 4])
1318+ >>> lax.pad(x, 0, [(2, 3, 0)])
1319+ Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
1320+
1321+ Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero
1322+ between each value:
1323+
1324+ >>> lax.pad(x, 0, [(0, 0, 1)])
1325+ Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
1326+
1327+ Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad
1328+ size of 2 in each dimension:
1329+
1330+ >>> x = jnp.array([[1, 2, 3],
1331+ ... [4, 5, 6]])
1332+ >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
1333+ Array([[-1, -1, -1, -1, -1, -1, -1],
1334+ [-1, -1, -1, -1, -1, -1, -1],
1335+ [-1, -1, 1, 2, 3, -1, -1],
1336+ [-1, -1, 4, 5, 6, -1, -1],
1337+ [-1, -1, -1, -1, -1, -1, -1],
1338+ [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
13091339 """
13101340 return pad_p .bind (operand , padding_value , padding_config = tuple (padding_config ))
13111341
0 commit comments