JAX equivalent to TensorFlow tf.gather_nd() #11422
Unanswered
conceptofmind
asked this question in
Q&A
Replies: 2 comments
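def jax_gather_nd(params, indices):
    # Split the last axis of `indices` into one index array per indexed axis of `params`.
    tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    return params[tuple_indices]

In your case:

import jax.numpy as jnp

def _gather(x, indices, gather_axis):
    # Coordinates of every element of `indices`, one row per element, in row-major order.
    # Assumption: the original call was missing its fill value; True with argwhere
    # mirrors the usual tf.where(tf.fill(shape, True)) idiom.
    all_indices = jnp.argwhere(jnp.full(indices.shape, True))
    # Flatten the indices; arrays in JAX have .size, not shape.num_elements().
    gather_locations = jnp.reshape(indices, (indices.size,))
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(gather_locations)
        else:
            gather_indices.append(all_indices[:, axis])
    gathered = x[tuple(gather_indices)]
    reshaped = jnp.reshape(gathered, indices.shape)
    return reshaped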
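As a quick check of the helper above, here is a minimal sketch that runs jax_gather_nd on a small array. The example arrays (params, coords, rows, idx) are made up for illustration. The helper does not handle the batch_dims argument of tf.gather_nd. The last lines use jnp.take_along_axis, which, as far as I can tell, covers the same per-axis gather that _gather is doing, so it may be worth trying before hand-rolling the loop.

```python
import jax.numpy as jnp


def jax_gather_nd(params, indices):
    # Split the last axis of `indices` into one index array per indexed axis of `params`.
    tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    return params[tuple_indices]


# Hypothetical example data: a 3x4 matrix holding 0..11.
params = jnp.arange(12).reshape(3, 4)

# Full-depth indices: each row is a (row, col) coordinate, so scalars come back.
coords = jnp.array([[0, 1], [2, 3]])
picked = jax_gather_nd(params, coords)

# Partial-depth indices: each row names only the first axis, so whole rows come back.
rows = jnp.array([[2], [0]])
slices = jax_gather_nd(params, rows)

# Per-axis gather: for each row i, pick the columns listed in idx[i].
idx = jnp.array([[3, 0], [1, 1], [2, 2]])
along = jnp.take_along_axis(params, idx, axis=1)

print(picked)  # [ 1 11]
print(slices)  # rows 2 and 0 of params
print(along)   # [[ 3  0] [ 5  5] [10 10]]
```

Plain advanced indexing like this is traceable under jax.jit as long as the index shapes are static.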
You could also refer to TFP's implementation of
Hi,
Is there an equivalent to tf.gather_nd() in JAX? For example, I have defined a function to gather indices:
I greatly appreciate any help.
Thank you,
Enrico