@@ -4086,12 +4086,54 @@ def trim_zeros_tol(filt, tol, trim='fb'):
4086
4086
end = argmin (nz [::- 1 ]) if 'b' in trim .lower () else 0
4087
4087
return filt [start :len (filt ) - end ]
4088
4088
4089
-
4090
- @util .implements (np .append )
4091
4089
@partial (jit , static_argnames = ('axis' ,))
4092
4090
def append (
4093
4091
arr : ArrayLike , values : ArrayLike , axis : int | None = None
4094
4092
) -> Array :
4093
+ """Return a new array with values appended to the end of the original array.
4094
+
4095
+ JAX implementation of :func:`numpy.append`.
4096
+
4097
+ Args:
4098
+ arr: original array.
4099
+ values: values to be appended to the array. The ``values`` must have
4100
+ the same number of dimensions as ``arr``, and all dimensions must
4101
+ match except in the specified axis.
4102
+ axis: axis along which to append values. If None (default), both ``arr``
4103
+ and ``values`` will be flattened before appending.
4104
+
4105
+ Returns:
4106
+ A new array with values appended to ``arr``.
4107
+
4108
+ See also:
4109
+ - :func:`jax.numpy.insert`
4110
+ - :func:`jax.numpy.delete`
4111
+
4112
+ Examples:
4113
+ >>> a = jnp.array([1, 2, 3])
4114
+ >>> b = jnp.array([4, 5, 6])
4115
+ >>> jnp.append(a, b)
4116
+ Array([1, 2, 3, 4, 5, 6], dtype=int32)
4117
+
4118
+ Appending along a specific axis:
4119
+
4120
+ >>> a = jnp.array([[1, 2],
4121
+ ... [3, 4]])
4122
+ >>> b = jnp.array([[5, 6]])
4123
+ >>> jnp.append(a, b, axis=0)
4124
+ Array([[1, 2],
4125
+ [3, 4],
4126
+ [5, 6]], dtype=int32)
4127
+
4128
+ Appending along a trailing axis:
4129
+
4130
+ >>> a = jnp.array([[1, 2, 3],
4131
+ ... [4, 5, 6]])
4132
+ >>> b = jnp.array([[7], [8]])
4133
+ >>> jnp.append(a, b, axis=1)
4134
+ Array([[1, 2, 3, 7],
4135
+ [4, 5, 6, 8]], dtype=int32)
4136
+ """
4095
4137
if axis is None :
4096
4138
return concatenate ([ravel (arr ), ravel (values )], 0 )
4097
4139
else :
0 commit comments