Skip to content

Commit 25cc84b

Browse files
author
jax authors
committed
Merge pull request #21615 from selamw1:append_doc
PiperOrigin-RevId: 641344856
2 parents dfc6076 + d77af7a commit 25cc84b

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4086,12 +4086,54 @@ def trim_zeros_tol(filt, tol, trim='fb'):
40864086
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
40874087
return filt[start:len(filt) - end]
40884088

4089-
4090-
@util.implements(np.append)
40914089
@partial(jit, static_argnames=('axis',))
40924090
def append(
40934091
arr: ArrayLike, values: ArrayLike, axis: int | None = None
40944092
) -> Array:
4093+
"""Return a new array with values appended to the end of the original array.
4094+
4095+
JAX implementation of :func:`numpy.append`.
4096+
4097+
Args:
4098+
arr: original array.
4099+
values: values to be appended to the array. The ``values`` must have
4100+
the same number of dimensions as ``arr``, and all dimensions must
4101+
match except in the specified axis.
4102+
axis: axis along which to append values. If None (default), both ``arr``
4103+
and ``values`` will be flattened before appending.
4104+
4105+
Returns:
4106+
A new array with values appended to ``arr``.
4107+
4108+
See also:
4109+
- :func:`jax.numpy.insert`
4110+
- :func:`jax.numpy.delete`
4111+
4112+
Examples:
4113+
>>> a = jnp.array([1, 2, 3])
4114+
>>> b = jnp.array([4, 5, 6])
4115+
>>> jnp.append(a, b)
4116+
Array([1, 2, 3, 4, 5, 6], dtype=int32)
4117+
4118+
Appending along a specific axis:
4119+
4120+
>>> a = jnp.array([[1, 2],
4121+
... [3, 4]])
4122+
>>> b = jnp.array([[5, 6]])
4123+
>>> jnp.append(a, b, axis=0)
4124+
Array([[1, 2],
4125+
[3, 4],
4126+
[5, 6]], dtype=int32)
4127+
4128+
Appending along a trailing axis:
4129+
4130+
>>> a = jnp.array([[1, 2, 3],
4131+
... [4, 5, 6]])
4132+
>>> b = jnp.array([[7], [8]])
4133+
>>> jnp.append(a, b, axis=1)
4134+
Array([[1, 2, 3, 7],
4135+
[4, 5, 6, 8]], dtype=int32)
4136+
"""
40954137
if axis is None:
40964138
return concatenate([ravel(arr), ravel(values)], 0)
40974139
else:

0 commit comments

Comments
 (0)