I suspect the most convenient way to do this will be using ndarray.at with vmap:

import jax
import jax.numpy as jnp

source = jnp.array([[77, 82, 29, 64, 62, 33, 97], [46, 50,  9, 54, 66, 58, 82]])
target = jnp.array([[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]])
indices = jnp.array([[0, 2], [1, 4]])

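@jax.vmap
def f(source, target, indices):
  # vmap maps over the leading axis, so each argument here is a single 1-D row.
  # len(indices) is a static shape value, so the slice is safe under tracing.
  return target.at[indices].set(source[:len(indices)])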

print(f(source, target, indices))
# [[77  0 82  0  0  0  0]
#  [ 0 46  0  0 50  0  0]]
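
For comparison, here is a minimal sketch of the same scatter written without vmap, using an explicit row index that broadcasts against indices. The names rows, values, and result are introduced only for this illustration, and it assumes that every row has the same number of indices, as in the example above.

```python
import jax.numpy as jnp

source = jnp.array([[77, 82, 29, 64, 62, 33, 97], [46, 50, 9, 54, 66, 58, 82]])
target = jnp.zeros_like(source)
indices = jnp.array([[0, 2], [1, 4]])

# Column vector of row numbers; broadcasts against `indices` to shape (2, 2).
rows = jnp.arange(indices.shape[0])[:, None]

# Take the first k entries of each source row, where k is the number of
# indices per row (an assumption carried over from the example above).
values = source[:, :indices.shape[1]]

# Scatter values[i, j] into target[i, indices[i, j]].
result = target.at[rows, indices].set(values)

print(result)
# [[77  0 82  0  0  0  0]
#  [ 0 46  0  0 50  0  0]]
```

Both versions should give the same result here. The vmap form is arguably easier to read because the function body only has to handle a single row.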

Answer selected by AakashKumarNain