
lax.dynamic_update_slice can only be used for operations where the size of the update slice is static (known at trace time), although its start location may be dynamic; for example:

from jax import lax
import jax.numpy as jnp
x = jnp.zeros(10)
y = jnp.ones(3)
lax.dynamic_update_slice(x, y, (3,))
# DeviceArray([0., 0., 0., 1., 1., 1., 0., 0., 0., 0.], dtype=float32)
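A minimal sketch of the static-size/dynamic-start distinction, assuming the update happens inside a `jax.jit`-compiled function (the names `insert_ones` and `insert_ones_dynamic_size` are hypothetical). The start index can be an ordinary traced argument, while an update whose length depends on a traced value is rejected when the function is traced:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def insert_ones(x, start):
    # `start` is traced (dynamic) under jit, so different start positions
    # reuse the same compiled function.
    # The update's shape (3,) is fixed at trace time, as dynamic_update_slice requires.
    y = jnp.ones(3, dtype=x.dtype)
    return lax.dynamic_update_slice(x, y, (start,))


@jax.jit
def insert_ones_dynamic_size(x, n):
    # Here the *size* of the update depends on the traced value n, so
    # jnp.ones(n) cannot build a fixed-shape array and tracing fails.
    return lax.dynamic_update_slice(x, jnp.ones(n, dtype=x.dtype), (0,))


x = jnp.zeros(10)
print(insert_ones(x, 3))  # [0. 0. 0. 1. 1. 1. 0. 0. 0. 0.]
print(insert_ones(x, 5))  # [0. 0. 0. 0. 0. 1. 1. 1. 0. 0.]

try:
    insert_ones_dynamic_size(x, 3)
except TypeError as err:
    # JAX reports a non-concrete shape as a TypeError (or a subclass of it).
    print("dynamic update size rejected:", type(err).__name__)
```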

One way to achieve the operation you have in mind is using jnp.where:

ind = jnp.arange(arr.shape[0])
# Replace entries strictly between i and k with arr[i]; keep all other entries unchanged.
arr = jnp.where((ind > i) & (ind < k), arr[i], arr)
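The jnp.where approach works under `jax.jit` because the mask always has the full, static shape of `arr`; only its values depend on `i` and `k`. A sketch, wrapping it in a hypothetical `fill_between` function and assuming the intent is to overwrite the positions strictly between `i` and `k` with `arr[i]`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def fill_between(arr, i, k):
    # i and k are traced values, so both the position and the length of the
    # filled region may change between calls without recompiling.
    ind = jnp.arange(arr.shape[0])   # static shape: arr.shape[0]
    mask = (ind > i) & (ind < k)     # boolean mask with the same static shape
    return jnp.where(mask, arr[i], arr)


arr = jnp.arange(8.0)
print(fill_between(arr, 2, 6))  # [0. 1. 2. 2. 2. 2. 6. 7.]
print(fill_between(arr, 0, 3))  # [0. 0. 0. 3. 4. 5. 6. 7.]
```

The trade-off is that jnp.where evaluates every element of arr rather than only the updated range; in exchange, neither the start nor the length of the filled region has to be known at trace time.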

@AdrienCorenflos

@jakevdp

Answer selected by jakevdp
Category
Q&A
2 participants