How to keep jax.random.choice untraced in a jitted function? #12180
-
I would love to keep the output of jax.random.choice untraced inside a jitted function.
With ...
Without ...
How can I keep the output of jax.random.choice untraced? Thanks in advance!
Replies: 8 comments 5 replies
-
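There is no way to make random.choice execute in a non-traced manner within a jit-compiled function. That said, depending on what your actual application is (I suspect you're not actually concerned with computing the length of a string repr of a traced variable), it may be possible to do what you want using the new jax.pure_callback function, which requires JAX version 0.3.17 (not yet released as I write this).
Make sure jax is up to date: ... and then you can do something like this:

import jax
import numpy as np

def get_str_len(x):
    # Runs on the host with a concrete value, so str() sees the real number.
    return np.array(len(str(x)), dtype='int32')

@jax.jit
def fun(key):
    a = jax.random.choice(key, 5)
    # Declare the callback's result shape and dtype up front.
    out_type = jax.ShapeDtypeStruct((), 'int32')
    return jax.pure_callback(get_str_len, out_type, a)

print(fun(jax.random.PRNGKey(0)))
# 1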
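If the sampled value is only needed as a concrete Python value, another option (not taken from the answer above, just a minimal sketch) is to sample eagerly outside the jitted function and pass the result in as a static argument. The names `scale_by_len` and `n_chars` are hypothetical and only serve the illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example function: argument 0 is static, so inside the
# function it is an ordinary Python int rather than a tracer.
@partial(jax.jit, static_argnums=0)
def scale_by_len(n_chars, x):
    return x * n_chars


key = jax.random.PRNGKey(0)
# Sample outside jit and convert to a concrete Python int.
a = int(jax.random.choice(key, 5))
# Plain Python logic is fine here because `a` is not traced.
n_chars = len(str(a))
result = scale_by_len(n_chars, jnp.float32(2.0))
```

Each distinct value of a static argument triggers a new compilation, so this approach fits best when the sampled value takes only a few different values.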
-
Thanks a lot, this sounds like it should work for my application. Any indication when JAX version 0.3.17 might be released?
-
Thanks a lot! However, I can't find it among the CUDA releases here:
Would it be possible to add it?
-
I tried
-
The object is a flax.core.TrainState. What should the output type be instead of
For more clarity: I'm trying to restore a checkpoint from a random path. The path has the following form:
I would solve the problem if I sampled
I hope this clarifies things rather than going into too much detail.
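For a callback that returns a structured object rather than a single array, one way to describe the output type is a pytree of `jax.ShapeDtypeStruct` entries with the same structure as the result. The sketch below is only an illustration under assumptions: it uses a plain dict of arrays instead of a Flax train state, and the in-memory `CHECKPOINTS` table and `load_checkpoint` helper are hypothetical stand-ins for reading a checkpoint from a sampled path.

```python
import jax
import numpy as np

# Hypothetical stand-in for checkpoints on disk: index -> pytree of arrays.
CHECKPOINTS = {
    i: {
        "w": np.full((2, 3), float(i), dtype=np.float32),
        "b": np.full((3,), float(i), dtype=np.float32),
    }
    for i in range(5)
}

# Build the output spec from a template with the same structure,
# shapes and dtypes as every checkpoint.
template = CHECKPOINTS[0]
out_type = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), template
)


def load_checkpoint(idx):
    # Runs on the host with a concrete value, so ordinary Python works here.
    return CHECKPOINTS[int(idx)]


@jax.jit
def restore_random(key):
    idx = jax.random.choice(key, 5)
    return jax.pure_callback(load_checkpoint, out_type, idx)


params = restore_random(jax.random.PRNGKey(0))
```

The callback has to return a pytree whose structure, shapes and dtypes match `out_type`, so every checkpoint that can be sampled needs the same layout.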
-
It worked, thanks!
-
Sorry to bring this up again, but I noticed that this does not work as expected on GPU. Although the device should make no difference, the very same piece of code involving the
In principle, would there be any reason for
-
I found out that
I suspect there is a bug on the Flax side of things. Do you have any intuition about where the problem comes from? Thank you in advance!
There is no way to make random.choice execute in a non-traced manner within a jit-compiled function. That said, depending on what your actual application is (I suspect you're not actually concerned with computing the length of a string repr of a traced variable), it may be possible to do what you want using the new jax.pure_callback function, which requires JAX version 0.3.17 (not yet released as I write this).
Make sure jax is up to date: ... and then you can do something like this: