Skip to content

Commit d77af7a

Browse files
committed
append_docstring_added
append_docstring_modified append_doc_line_break append_doc_linting_fixed
1 parent eba0564 commit d77af7a

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3955,12 +3955,54 @@ def trim_zeros_tol(filt, tol, trim='fb'):
39553955
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
39563956
return filt[start:len(filt) - end]
39573957

3958-
3959-
@util.implements(np.append)
39603958
@partial(jit, static_argnames=('axis',))
39613959
def append(
39623960
arr: ArrayLike, values: ArrayLike, axis: int | None = None
39633961
) -> Array:
3962+
"""Return a new array with values appended to the end of the original array.
3963+
3964+
JAX implementation of :func:`numpy.append`.
3965+
3966+
Args:
3967+
arr: original array.
3968+
values: values to be appended to the array. The ``values`` must have
3969+
the same number of dimensions as ``arr``, and all dimensions must
3970+
match except in the specified axis.
3971+
axis: axis along which to append values. If None (default), both ``arr``
3972+
and ``values`` will be flattened before appending.
3973+
3974+
Returns:
3975+
A new array with values appended to ``arr``.
3976+
3977+
See also:
3978+
- :func:`jax.numpy.insert`
3979+
- :func:`jax.numpy.delete`
3980+
3981+
Examples:
3982+
>>> a = jnp.array([1, 2, 3])
3983+
>>> b = jnp.array([4, 5, 6])
3984+
>>> jnp.append(a, b)
3985+
Array([1, 2, 3, 4, 5, 6], dtype=int32)
3986+
3987+
Appending along a specific axis:
3988+
3989+
>>> a = jnp.array([[1, 2],
3990+
... [3, 4]])
3991+
>>> b = jnp.array([[5, 6]])
3992+
>>> jnp.append(a, b, axis=0)
3993+
Array([[1, 2],
3994+
[3, 4],
3995+
[5, 6]], dtype=int32)
3996+
3997+
Appending along a trailing axis:
3998+
3999+
>>> a = jnp.array([[1, 2, 3],
4000+
... [4, 5, 6]])
4001+
>>> b = jnp.array([[7], [8]])
4002+
>>> jnp.append(a, b, axis=1)
4003+
Array([[1, 2, 3, 7],
4004+
[4, 5, 6, 8]], dtype=int32)
4005+
"""
39644006
if axis is None:
39654007
return concatenate([ravel(arr), ravel(values)], 0)
39664008
else:

0 commit comments

Comments
 (0)