scatter and update #19880
Asked by AakashKumarNain in Q&A
I have a JAX array of shape …
I am not sure how to do this with …
Answered by jakevdp on Feb 19, 2024
I suspect the most convenient way to do this will be using `ndarray.at` with `vmap`:

```python
import jax
import jax.numpy as jnp

source = jnp.array([[77, 82, 29, 64, 62, 33, 97], [46, 50, 9, 54, 66, 58, 82]])
target = jnp.array([[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]])
indices = jnp.array([[0, 2], [1, 4]])

# vmap maps f over the leading (row) axis of all three arguments
@jax.vmap
def f(source, target, indices):
    # For one row: write the first len(indices) source values into target at those positions
    return target.at[indices].set(source[:len(indices)])

print(f(source, target, indices))
# [[77  0 82  0  0  0  0]
#  [ 0 46  0  0 50  0  0]]
```

Answer selected by AakashKumarNain
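As an alternative to the `vmap` approach, the same per-row scatter can likely also be written as a single 2-D indexed update: pair `indices` with a broadcast array of row numbers so that each `(row, column)` pair names one target position. This is a sketch, not part of the selected answer. It assumes every row of `indices` has the same length `k` (true for any rectangular array) and that the first `k` entries of each `source` row are the values to write. The helper names `scatter_rows_vmap` and `scatter_rows_indexed` are hypothetical.

```python
import jax
import jax.numpy as jnp


def scatter_rows_vmap(source, target, indices):
    """Per-row `.at[].set`, batched over rows with vmap (as in the answer)."""
    def one_row(s, t, idx):
        # Write the first idx.shape[0] values of s into t at positions idx.
        return t.at[idx].set(s[: idx.shape[0]])
    return jax.vmap(one_row)(source, target, indices)


def scatter_rows_indexed(source, target, indices):
    """Same update without vmap, using broadcast 2-D advanced indexing."""
    n_rows, k = indices.shape
    # rows has shape (n_rows, 1) and broadcasts against indices (n_rows, k),
    # so element [i, j] addresses target[i, indices[i, j]].
    rows = jnp.arange(n_rows)[:, None]
    return target.at[rows, indices].set(source[:, :k])


source = jnp.array([[77, 82, 29, 64, 62, 33, 97], [46, 50, 9, 54, 66, 58, 82]])
target = jnp.zeros_like(source)
indices = jnp.array([[0, 2], [1, 4]])

out_vmap = scatter_rows_vmap(source, target, indices)
out_indexed = scatter_rows_indexed(source, target, indices)
print(jnp.array_equal(out_vmap, out_indexed))  # True
```

The `vmap` form is easier to extend when the per-row logic becomes more involved, while the indexed form keeps the whole update in one explicit expression; on this input both give the same array.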