Skip to content

Commit 334bd4d

Browse files
Merge pull request jax-ml#25019 from jakevdp:lax-pad-doc
PiperOrigin-RevId: 698556681
2 parents 6fe7804 + 2699e95 commit 334bd4d

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

jax/_src/lax/lax.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike,
13061306
Returns:
13071307
The ``operand`` array with padding value ``padding_value`` inserted in each
13081308
dimension according to the ``padding_config``.
1309+
1310+
Examples:
1311+
>>> from jax import lax
1312+
>>> import jax.numpy as jnp
1313+
1314+
Pad a 1-dimensional array with zeros, We'll specify two zeros in front and
1315+
three at the end:
1316+
1317+
>>> x = jnp.array([1, 2, 3, 4])
1318+
>>> lax.pad(x, 0, [(2, 3, 0)])
1319+
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
1320+
1321+
Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero
1322+
between each value:
1323+
1324+
>>> lax.pad(x, 0, [(0, 0, 1)])
1325+
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
1326+
1327+
Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad
1328+
size of 2 in each dimension:
1329+
1330+
>>> x = jnp.array([[1, 2, 3],
1331+
... [4, 5, 6]])
1332+
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
1333+
Array([[-1, -1, -1, -1, -1, -1, -1],
1334+
[-1, -1, -1, -1, -1, -1, -1],
1335+
[-1, -1, 1, 2, 3, -1, -1],
1336+
[-1, -1, 4, 5, 6, -1, -1],
1337+
[-1, -1, -1, -1, -1, -1, -1],
1338+
[-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
13091339
"""
13101340
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
13111341

0 commit comments

Comments
 (0)