Skip to content

Commit 1e41d5e

Browse files
Merge pull request jax-ml#24452 from jakevdp:insert-doc
PiperOrigin-RevId: 688624762
2 parents 1a2737b + 48dd153 commit 1e41d5e

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8307,6 +8307,9 @@ def delete(
83078307
may specify ``assume_unique_indices=True`` to perform the operation in a
83088308
manner that does not require static indices.
83098309
8310+
See also:
8311+
- :func:`jax.numpy.insert`: insert entries into an array.
8312+
83108313
Examples:
83118314
Delete entries from a 1D array:
83128315
@@ -8400,9 +8403,55 @@ def delete(
84008403
return a[tuple(slice(None) for i in range(axis)) + (mask,)]
84018404

84028405

8403-
@util.implements(np.insert)
84048406
def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
84058407
axis: int | None = None) -> Array:
8408+
"""Insert entries into an array at specified indices.
8409+
8410+
JAX implementation of :func:`numpy.insert`.
8411+
8412+
Args:
8413+
arr: array object into which values will be inserted.
8414+
obj: slice or array of indices specifying insertion locations.
8415+
values: array of values to be inserted.
8416+
axis: specify the insertion axis in the case of multi-dimensional
8417+
arrays. If unspecified, ``arr`` will be flattened.
8418+
8419+
Returns:
8420+
A copy of ``arr`` with values inserted at the specified locations.
8421+
8422+
See also:
8423+
- :func:`jax.numpy.delete`: delete entries from an array.
8424+
8425+
Examples:
8426+
Inserting a single value:
8427+
8428+
>>> x = jnp.arange(5)
8429+
>>> jnp.insert(x, 2, 99)
8430+
Array([ 0, 1, 99, 2, 3, 4], dtype=int32)
8431+
8432+
Inserting multiple identical values using a slice:
8433+
8434+
>>> jnp.insert(x, slice(None, None, 2), -1)
8435+
Array([-1, 0, 1, -1, 2, 3, -1, 4], dtype=int32)
8436+
8437+
Inserting multiple values using an index:
8438+
8439+
>>> indices = jnp.array([4, 2, 5])
8440+
>>> values = jnp.array([10, 11, 12])
8441+
>>> jnp.insert(x, indices, values)
8442+
Array([ 0, 1, 11, 2, 3, 10, 4, 12], dtype=int32)
8443+
8444+
Inserting columns into a 2D array:
8445+
8446+
>>> x = jnp.array([[1, 2, 3],
8447+
... [4, 5, 6]])
8448+
>>> indices = jnp.array([1, 3])
8449+
>>> values = jnp.array([[10, 11],
8450+
... [12, 13]])
8451+
>>> jnp.insert(x, indices, values, axis=1)
8452+
Array([[ 1, 10, 2, 3, 11],
8453+
[ 4, 12, 5, 6, 13]], dtype=int32)
8454+
"""
84068455
util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
84078456
a = asarray(arr)
84088457
values_arr = asarray(values)

0 commit comments

Comments
 (0)