Hello, I'm working on a function in which one step is to randomly select an element of an array that meets a certain condition and return that element's index. If the function is not decorated with `jax.jit`, this works.
But if the function is decorated with `jax.jit`, it fails. Is there an alternative?
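A minimal sketch of why `jax.jit` gets in the way, assuming the un-jitted version used the one-argument form of `jnp.where` (the helper name `eligible_indices` is hypothetical): the number of matches depends on the array's values, so the result's shape is unknown while tracing and JAX raises a `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def eligible_indices(arr):
    # Hypothetical helper: the length of the result depends on the data,
    # so its shape cannot be determined at trace time.
    return jnp.where(arr.ravel() > 2)[0]


arr = jnp.array([[1, 2, 3], [4, 5, 6]])
print(eligible_indices(arr))  # works eagerly: [2 3 4 5]

try:
    jax.jit(eligible_indices)(arr)
except jax.errors.ConcretizationTypeError as err:
    # Under jit the number of matches is an abstract tracer, not a concrete int.
    print("fails under jit:", type(err).__name__)
```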
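You can use the `size` argument to `jnp.where` to keep the size of the arrays static, and then use `randint` to index into the valid elements:

```python
import jax
import jax.numpy as jnp


def get_eligible_cells(arr, step):
    key = jax.random.PRNGKey(step)
    # Flatten the mask so the returned value is a flat index into arr.
    mask = (arr > 2).ravel()
    # With a static `size`, jnp.where always returns arr.size indices: the
    # matching ones come first, the rest are padded with fill_value (0).
    matched_inds = jnp.where(mask, size=arr.size)[0]
    # Draw an integer in [0, number of matches); maxval may be a traced value.
    # If nothing matches, mask.sum() is 0 and the result is not meaningful.
    idx = jax.random.randint(key, shape=(), minval=0, maxval=mask.sum())
    return matched_inds[idx]


array_2d = jnp.array([[1, 2, 3], [4, 5, 6]])
ind = jax.jit(get_eligible_cells)(array_2d, 0)
print(ind)
# 4
```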
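An alternative sketch, not taken from the answer: `jax.random.choice` accepts a probability vector `p`, so giving zero weight to non-matching entries samples only eligible indices while every shape stays static. It also uses `jnp.unravel_index` to turn the flat index back into row/column coordinates. The helper name `sample_eligible_index` is hypothetical, and the sketch assumes at least one element matches.

```python
import jax
import jax.numpy as jnp


def sample_eligible_index(arr, key):
    # Hypothetical helper: pick a flat index uniformly among entries > 2.
    mask = (arr > 2).ravel()
    # Uniform weight on matching entries, zero weight elsewhere.
    # Assumes at least one entry matches; otherwise p would contain NaN.
    p = mask / mask.sum()
    # arr.size is a static Python int, so the sampling shape is fixed.
    flat_idx = jax.random.choice(key, arr.size, p=p)
    # Convert the flat index back to (row, col) coordinates.
    return flat_idx, jnp.unravel_index(flat_idx, arr.shape)


arr = jnp.array([[1, 2, 3], [4, 5, 6]])
flat_idx, (row, col) = jax.jit(sample_eligible_index)(arr, jax.random.PRNGKey(0))
print(flat_idx, row, col)
```

Both approaches do O(`arr.size`) work per call; what matters for `jax.jit` is that neither produces an array whose shape depends on how many elements match.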