@@ -8307,6 +8307,9 @@ def delete(
83078307 may specify ``assume_unique_indices=True`` to perform the operation in a
83088308 manner that does not require static indices.
83098309
8310+ See also:
8311+ - :func:`jax.numpy.insert`: insert entries into an array.
8312+
83108313 Examples:
83118314 Delete entries from a 1D array:
83128315
@@ -8400,9 +8403,55 @@ def delete(
84008403 return a [tuple (slice (None ) for i in range (axis )) + (mask ,)]
84018404
84028405
8403- @util .implements (np .insert )
84048406def insert (arr : ArrayLike , obj : ArrayLike | slice , values : ArrayLike ,
84058407 axis : int | None = None ) -> Array :
8408+ """Insert entries into an array at specified indices.
8409+
8410+ JAX implementation of :func:`numpy.insert`.
8411+
8412+ Args:
8413+ arr: array object into which values will be inserted.
8414+ obj: slice or array of indices specifying insertion locations.
8415+ values: array of values to be inserted.
8416+ axis: specify the insertion axis in the case of multi-dimensional
8417+ arrays. If unspecified, ``arr`` will be flattened.
8418+
8419+ Returns:
8420+ A copy of ``arr`` with values inserted at the specified locations.
8421+
8422+ See also:
8423+ - :func:`jax.numpy.delete`: delete entries from an array.
8424+
8425+ Examples:
8426+ Inserting a single value:
8427+
8428+ >>> x = jnp.arange(5)
8429+ >>> jnp.insert(x, 2, 99)
8430+ Array([ 0, 1, 99, 2, 3, 4], dtype=int32)
8431+
8432+ Inserting multiple identical values using a slice:
8433+
8434+ >>> jnp.insert(x, slice(None, None, 2), -1)
8435+ Array([-1, 0, 1, -1, 2, 3, -1, 4], dtype=int32)
8436+
8437+ Inserting multiple values using an index:
8438+
8439+ >>> indices = jnp.array([4, 2, 5])
8440+ >>> values = jnp.array([10, 11, 12])
8441+ >>> jnp.insert(x, indices, values)
8442+ Array([ 0, 1, 11, 2, 3, 10, 4, 12], dtype=int32)
8443+
8444+ Inserting columns into a 2D array:
8445+
8446+ >>> x = jnp.array([[1, 2, 3],
8447+ ... [4, 5, 6]])
8448+ >>> indices = jnp.array([1, 3])
8449+ >>> values = jnp.array([[10, 11],
8450+ ... [12, 13]])
8451+ >>> jnp.insert(x, indices, values, axis=1)
8452+ Array([[ 1, 10, 2, 3, 11],
8453+ [ 4, 12, 5, 6, 13]], dtype=int32)
8454+ """
84068455 util .check_arraylike ("insert" , arr , 0 if isinstance (obj , slice ) else obj , values )
84078456 a = asarray (arr )
84088457 values_arr = asarray (values )
0 commit comments