@@ -3955,12 +3955,54 @@ def trim_zeros_tol(filt, tol, trim='fb'):
3955
3955
end = argmin (nz [::- 1 ]) if 'b' in trim .lower () else 0
3956
3956
return filt [start :len (filt ) - end ]
3957
3957
3958
-
3959
- @util .implements (np .append )
3960
3958
@partial (jit , static_argnames = ('axis' ,))
3961
3959
def append (
3962
3960
arr : ArrayLike , values : ArrayLike , axis : int | None = None
3963
3961
) -> Array :
3962
+ """Return a new array with values appended to the end of the original array.
3963
+
3964
+ JAX implementation of :func:`numpy.append`.
3965
+
3966
+ Args:
3967
+ arr: original array.
3968
+ values: values to be appended to the array. The ``values`` must have
3969
+ the same number of dimensions as ``arr``, and all dimensions must
3970
+ match except in the specified axis.
3971
+ axis: axis along which to append values. If None (default), both ``arr``
3972
+ and ``values`` will be flattened before appending.
3973
+
3974
+ Returns:
3975
+ A new array with values appended to ``arr``.
3976
+
3977
+ See also:
3978
+ - :func:`jax.numpy.insert`
3979
+ - :func:`jax.numpy.delete`
3980
+
3981
+ Examples:
3982
+ >>> a = jnp.array([1, 2, 3])
3983
+ >>> b = jnp.array([4, 5, 6])
3984
+ >>> jnp.append(a, b)
3985
+ Array([1, 2, 3, 4, 5, 6], dtype=int32)
3986
+
3987
+ Appending along a specific axis:
3988
+
3989
+ >>> a = jnp.array([[1, 2],
3990
+ ... [3, 4]])
3991
+ >>> b = jnp.array([[5, 6]])
3992
+ >>> jnp.append(a, b, axis=0)
3993
+ Array([[1, 2],
3994
+ [3, 4],
3995
+ [5, 6]], dtype=int32)
3996
+
3997
+ Appending along a trailing axis:
3998
+
3999
+ >>> a = jnp.array([[1, 2, 3],
4000
+ ... [4, 5, 6]])
4001
+ >>> b = jnp.array([[7], [8]])
4002
+ >>> jnp.append(a, b, axis=1)
4003
+ Array([[1, 2, 3, 7],
4004
+ [4, 5, 6, 8]], dtype=int32)
4005
+ """
3964
4006
if axis is None :
3965
4007
return concatenate ([ravel (arr ), ravel (values )], 0 )
3966
4008
else :
0 commit comments