Help with this small code: Can I sample from a list inside vmap? #7594
Asked by externalsupplierstaff in Q&A · Answered by jakevdp
What is the correct way to do the following?

```python
import jax
import jax.numpy as jnp

idxs = jnp.array([1, 0])

def sample_from_list(idx, mylist):
    return mylist[idx]

print(idxs)
print(sample_from_list(idxs[0], list('abcd')))  # works outside vmap
jax.vmap(sample_from_list, (0, None), 0)(idxs, list('abcd'))  # fails
```
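A minimal sketch of why the `vmap` call fails, assuming a recent JAX release: inside `jax.vmap`, `idx` is a tracer rather than a concrete integer, while indexing a Python list needs a concrete integer, so the lookup raises an error. In recent releases this should be `jax.errors.TracerIntegerConversionError`, which is a subclass of `TypeError`; the `failed` flag is only there for illustration.

```python
import jax
import jax.numpy as jnp

def sample_from_list(idx, mylist):
    return mylist[idx]

idxs = jnp.array([1, 0])

try:
    jax.vmap(sample_from_list, (0, None), 0)(idxs, list('abcd'))
    failed = False
except TypeError as err:
    # idx is a batch tracer here; list.__getitem__ cannot turn it into an int.
    failed = True
    print(type(err).__name__)
```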
Answered by jakevdp on Aug 11, 2021 · Replies: 2 comments, 1 reply
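Because the original `vmap` call fails, my workaround is a plain Python loop:

```python
def my_vmap_first_arg(fn, first_arg, *args):
    # Call fn once per element of first_arg, eagerly in Python (no tracing),
    # so the other arguments may be arbitrary Python objects such as lists of strings.
    ress = [fn(first_arg_e, *args) for first_arg_e in first_arg]
    return tuple(ress)
```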
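For comparison, NumPy (unlike JAX) does have a string dtype, so outside of any JAX transformation the same lookup can be vectorized without a Python loop. A minimal sketch, assuming the indices are concrete so `np.asarray` can convert them; `letters` and `picked` are illustrative names:

```python
import numpy as np
import jax.numpy as jnp

idxs = jnp.array([1, 0])
letters = np.array(list('abcd'))    # NumPy array with a string dtype ('<U1')
picked = letters[np.asarray(idxs)]  # fancy indexing with concrete integer indices
print(picked.tolist())  # ['b', 'a']
```

This only works on concrete values: it cannot run inside `jax.vmap` or `jax.jit`, for the same reason the original call fails.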
Hi - thanks for the question! This is not possible because JAX operations can only operate on JAX arrays (not lists), and JAX does not support a string dtype. If you had numerical data it would be fine:

```python
jax.vmap(sample_from_list, (0, None), 0)(idxs, jnp.array([1, 2, 3, 4]))
# DeviceArray([2, 1], dtype=int32)
```

But there is no way to do such computation with lists of strings in JAX.
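One way to work within this limitation, sketched under the assumption that the strings only need to be looked up and not computed on: keep the strings in a plain Python list, give JAX integer codes instead, run the `vmap` on the codes, and decode the result after leaving JAX. `vocab`, `codes`, and `sample_code` are hypothetical names.

```python
import jax
import jax.numpy as jnp

vocab = list('abcd')            # strings stay outside JAX
codes = jnp.arange(len(vocab))  # integer code i stands for vocab[i]

def sample_code(idx, table):
    # Numeric lookup: fine under vmap because table is a JAX array.
    return table[idx]

idxs = jnp.array([1, 0])
picked_codes = jax.vmap(sample_code, in_axes=(0, None))(idxs, codes)

# Decode back to strings in plain Python, outside any JAX transformation.
picked = [vocab[int(c)] for c in picked_codes]
print(picked)  # ['b', 'a']
```

Here the codes are just `arange`, so the lookup is trivial; in real code the JAX side would do whatever numeric work produces the indices, and only the final decoding step touches the strings.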
Answer selected by externalsupplierstaff