Hi, thanks for the question! This is not possible: JAX operations only work on JAX arrays (not Python lists), and JAX does not support a string dtype. If you had numerical data, it would work fine (with sample_from_list and idxs as defined in your question):

import jax
import jax.numpy as jnp
jax.vmap(sample_from_list, (0, None), 0)(idxs, jnp.array([1, 2, 3, 4]))
# DeviceArray([2, 1], dtype=int32)
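Here is a minimal self-contained sketch of the numerical case. The question's sample_from_list and idxs are not shown on this page, so the definitions below are hypothetical stand-ins (a plain index lookup and two indices) chosen to reproduce the [2, 1] result above. The in_axes=(0, None) argument tells vmap to map over the first axis of the indices while passing the value array unchanged to every call.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the question's `sample_from_list`:
# return the element of `values` at position `idx`.
def sample_from_list(idx, values):
    return values[idx]


# Hypothetical indices: one lookup is performed per entry.
idxs = jnp.array([1, 0])
values = jnp.array([1, 2, 3, 4])

# in_axes=(0, None): map over axis 0 of `idxs`, reuse `values` as-is.
# out_axes=0: stack the per-index results along a new leading axis.
batched = jax.vmap(sample_from_list, in_axes=(0, None), out_axes=0)
result = batched(idxs, values)
print(result)  # [2 1]
```

Newer JAX versions display the result as Array([2, 1], dtype=int32) rather than DeviceArray, but the values are the same.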

but there is no way to do this kind of computation on lists of strings in JAX.
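The sketch below checks the string limitation directly and shows one pattern that stays within it. The pattern is an assumption, not something proposed in this answer: the strings stay in a plain Python list, JAX only computes integer positions, and the lookup back to strings happens outside JAX. The helper strings_supported and the sample words are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

words = ["apple", "banana", "cherry", "date"]


def strings_supported(strings):
    """Hypothetical helper: True if JAX can build an array from `strings`."""
    try:
        # NumPy infers a unicode dtype (here <U6), which JAX rejects.
        jnp.asarray(np.array(strings))
    except (TypeError, ValueError):
        return False
    return True


# Assumed workaround: JAX only sees integer positions, never the strings.
positions = jnp.arange(len(words))
idxs = jnp.array([1, 0])
picked = jax.vmap(lambda i, p: p[i], in_axes=(0, None))(idxs, positions)

# Back in plain Python, map the integer results to the original strings.
picked_words = [words[int(i)] for i in picked]
print(strings_supported(words), picked_words)  # False ['banana', 'apple']
```

The string lookup in the last step runs in ordinary Python, so it cannot itself be traced, jitted, or vmapped by JAX.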

Answer selected by externalsupplierstaff