Pass `size=arr.size` to `jnp.where` so that its output has a static shape (which `jax.jit` requires), then use `jax.random.randint` to draw an index among the valid elements:

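```python
import jax
import jax.numpy as jnp

def get_eligible_cells(arr, step):
    key = jax.random.PRNGKey(step)
    mask = arr > 2
    # size=arr.size makes the output shape static, so this works under jit.
    # Slots past the last real match are padded with 0.
    matched_inds = jnp.where(jnp.reshape(mask, -1), size=arr.size)[0]
    # Only the first mask.sum() entries are real matches; pick one of them.
    idx = jax.random.randint(key, shape=(), minval=0, maxval=mask.sum())
    # The result is a flat (1-D) index into arr.
    return matched_inds[idx]

array_2d = jnp.array([[1, 2, 3], [4, 5, 6]])
ind = jax.jit(get_eligible_cells)(array_2d, 0)
print(ind)
# 4  (a flat index into array_2d)
```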
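One edge case the function above does not cover: if no element matches, `mask.sum()` is 0, `randint` falls back to `minval`, and the lookup reads a padding slot (0 by default), which looks like a genuine match at flat index 0. The sketch below guards that case and also converts the flat index to `(row, col)` with `jnp.unravel_index`. The function name `sample_eligible_cell`, the `threshold` parameter and the `(-1, -1)` "no match" sentinel are hypothetical choices for illustration, not part of the answer.

```python
import jax
import jax.numpy as jnp

def sample_eligible_cell(arr, step, threshold=2):
    """Pick a random (row, col) among elements of arr greater than threshold.

    Hypothetical helper: returns (-1, -1) when no element qualifies.
    """
    key = jax.random.PRNGKey(step)
    mask = arr > threshold
    count = mask.sum()
    # Static-size flat indices of the matches; padding slots hold 0.
    flat = jnp.where(mask.ravel(), size=arr.size)[0]
    # Keep maxval >= 1 so the draw is well defined even when count == 0.
    idx = jax.random.randint(key, shape=(), minval=0, maxval=jnp.maximum(count, 1))
    flat_idx = jnp.where(count > 0, flat[idx], -1)
    # Convert to (row, col); clamp first, then mark the empty case with -1.
    row, col = jnp.unravel_index(jnp.maximum(flat_idx, 0), arr.shape)
    row = jnp.where(count > 0, row, -1)
    col = jnp.where(count > 0, col, -1)
    return row, col
```

Usage: `row, col = jax.jit(sample_eligible_cell)(jnp.array([[1, 2, 3], [4, 5, 6]]), 0)`. Using a sentinel keeps the output shape fixed, so the function stays jit-compatible; callers can test `row >= 0` with `jnp.where` rather than a Python `if`.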

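A different technique from the one in the answer: `jax.random.choice` accepts a probability vector `p`, so giving every matching cell equal weight and every other cell zero weight samples a flat index directly, without `jnp.where`. This is a sketch; the name `sample_eligible_flat_index` and the `threshold` parameter are hypothetical. Like the original, it assumes at least one element matches, since an all-false mask makes `p` divide by zero.

```python
import jax
import jax.numpy as jnp

def sample_eligible_flat_index(arr, step, threshold=2):
    key = jax.random.PRNGKey(step)
    mask = (arr > threshold).ravel()
    # Uniform probability over matching cells, zero elsewhere.
    # Assumes mask.sum() > 0; otherwise p contains NaNs.
    p = mask / mask.sum()
    # With an int first argument, choice samples from jnp.arange(arr.size);
    # arr.size is a static Python int, so this is fine under jit.
    return jax.random.choice(key, arr.size, shape=(), p=p)
```

Usage: `jax.jit(sample_eligible_flat_index)(jnp.array([[1, 2, 3], [4, 5, 6]]), 0)` returns a flat index, as in the answer.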
Answer selected by ijku