
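If I'm understanding the question correctly, something like this should work. `lax.fori_loop` steps through the columns of `indices`. At step `i`, every row has its columns `0..i` overwritten by the same columns of the row selected by `indices[:, i]`, and its columns after `i` are left unchanged.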

```python
import jax.numpy as jnp
from jax import lax

indices = jnp.array([[0, 1, 2],
                     [1, 1, 2],
                     [1, 0, 1]])
data = jnp.array([[1.0, 1.0, 3.0],
                  [1.0, 2.0, 3.0],
                  [0.0, 2.0, 3.0]])

def loop(i, data):
  # Column positions 0..n-1, broadcast against every row of `data`.
  r = jnp.arange(indices.shape[1])
  # For each row, the row to copy from at step i.
  ind = indices[:, i]
  # Keep columns after i; take columns 0..i from the gathered rows.
  return jnp.where(r > i, data, data[ind])

result = lax.fori_loop(0, indices.shape[1], loop, data)

print(result)
# [[1. 1. 3.]
#  [1. 1. 3.]
#  [1. 2. 3.]]
```
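To make the update rule explicit, here is a plain-Python mirror of the loop body. It is only a sketch for checking what the `fori_loop` computes; the function name `reference` is hypothetical, it is not part of the original answer, and it is not meant to run inside JAX. Note that each step reads from a snapshot of the previous state, matching how `data[ind]` is gathered from the old carry before `jnp.where` builds the new one.

```python
def reference(indices, data):
    """Hypothetical plain-Python mirror of the fori_loop body (for checking only)."""
    n_rows = len(data)
    n_cols = len(indices[0])
    state = [list(row) for row in data]
    for i in range(n_cols):
        # Snapshot so every row reads the state from before step i.
        prev = [list(row) for row in state]
        for row in range(n_rows):
            src = prev[indices[row][i]]
            # Columns 0..i come from the source row; columns after i are kept.
            state[row] = src[: i + 1] + prev[row][i + 1 :]
    return state


indices = [[0, 1, 2],
           [1, 1, 2],
           [1, 0, 1]]
data = [[1.0, 1.0, 3.0],
        [1.0, 2.0, 3.0],
        [0.0, 2.0, 3.0]]

print(reference(indices, data))
# [[1.0, 1.0, 3.0], [1.0, 1.0, 3.0], [1.0, 2.0, 3.0]]
```

Running it on the same inputs reproduces the JAX result printed above.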

Answer selected by femtomc
Category: Q&A