dynamic_slice_update example #6547
-
Hi, while_loop...

    arr = index_update(arr, index[i + 1:k], arr[i])

JAX complains about this because the program needs to be static, and the error message tells me to use …
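For reference, here is a minimal sketch that reproduces the problem with current JAX. It assumes that jax.ops.index_update, which has since been removed from JAX, is replaced by its successor arr.at[...].set(...). The function name fill_slice and the sample array are hypothetical. Under jax.jit, i and k are traced values, so the length of i + 1:k is unknown at trace time and JAX rejects the slice.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_slice(arr, i, k):
    # Hypothetical modern spelling of index_update(arr, index[i + 1:k], arr[i]).
    # Under jit, i and k are tracers, so the slice length k - (i + 1) is not static.
    return arr.at[i + 1:k].set(arr[i])

arr = jnp.arange(6.0)
try:
    fill_slice(arr, 1, 4)
    raised = False
except Exception as err:  # JAX refuses slice bounds that are not static under jit
    raised = True
    print(type(err).__name__)
```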
-
One way to achieve the operation you have in mind is using jnp.where:

    # positions i+1 .. k-1 take the value arr[i]; every shape stays static
    ind = jnp.arange(arr.shape[0])
    arr = jnp.where((ind > i) & (ind < k), arr[i], arr)
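A self-contained sketch of the jnp.where approach follows. The mask selects exactly the positions i+1 through k-1, and arr[i] broadcasts into them, so no array ever has a data-dependent shape. The names fill_between and fill_segments are hypothetical, and the segment-filling loop is an invented use case. It only illustrates that the same update works inside lax.while_loop when i and k are loop-carried values whose difference changes from one iteration to the next.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_between(arr, i, k):
    """Return a copy of arr with arr[i+1:k] set to arr[i]; i and k may be traced."""
    ind = jnp.arange(arr.shape[0])       # static length, so this is fine under jit
    mask = (ind > i) & (ind < k)         # True exactly on positions i+1 .. k-1
    return jnp.where(mask, arr[i], arr)  # arr[i] is a scalar and broadcasts

@jax.jit
def fill_segments(arr, starts):
    """Hypothetical use: fill each segment [starts[j], starts[j+1]) with its first value."""
    bounds = jnp.append(starts, arr.shape[0])  # segment j covers [bounds[j], bounds[j+1])

    def cond(state):
        j, _ = state
        return j < starts.shape[0]

    def body(state):
        j, a = state
        # The fill length bounds[j+1] - bounds[j] - 1 varies per iteration.
        return j + 1, fill_between(a, bounds[j], bounds[j + 1])

    _, out = jax.lax.while_loop(cond, body, (0, arr))
    return out

arr = jnp.array([5.0, 1.0, 2.0, 3.0, 4.0, 6.0])
print(fill_between(arr, 1, 4))                    # [5. 1. 1. 1. 4. 6.]
print(fill_segments(arr, jnp.array([0, 2, 5])))   # [5. 5. 2. 2. 2. 6.]
```

The trade-off is that the mask is evaluated over the whole array on every call, so the cost is proportional to the array length rather than to k - i. In exchange, every shape is known at trace time, which is what jit and while_loop require.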
-
👍
On Fri, Apr 23, 2021 at 19:12, Jake Vanderplas ***@***.*** wrote:
… The docstring lacked any examples, so I've added a few in #6550
lax.dynamic_update_slice can only be used for operations where the slice size is static (the start location can be dynamic).
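A minimal sketch of what "static size, dynamic start" means in practice follows. The update always has length 3 (an arbitrary, hypothetical choice), while its start position is a traced argument. write_window is an illustrative name and is not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def write_window(arr, start):
    # The update has a fixed, static length of 3; only `start` is traced.
    update = jnp.full((3,), arr[start])
    # Write `update` into arr beginning at index `start`.
    return lax.dynamic_update_slice(arr, update, (start,))

arr = jnp.array([5.0, 1.0, 2.0, 3.0, 4.0, 6.0])
print(write_window(arr, 1))  # [5. 1. 1. 1. 4. 6.]
print(write_window(arr, 5))  # [5. 1. 2. 6. 6. 6.]  (start clamped to 3)
```

lax.dynamic_update_slice clamps the start index so that the update fits inside the array. That is why write_window(arr, 5) writes to positions 3 to 5 instead of running past the end. A fixed-length update like this cannot express arr[i + 1:k] when k - i varies between calls. That case is what the jnp.where mask in the reply above handles.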