There is no way to make `jax.random.choice` execute in a non-traced manner within a JIT-compiled function. That said, depending on your actual application (I suspect you're not really concerned with computing the length of the string repr of a traced variable), you may be able to do what you want with the new `jax.pure_callback` function, which requires JAX version 0.3.17 (not yet released as I write this). Make sure JAX is up to date:

```shell
$ pip install "jax>=0.3.17"
```

... and then you can do something like this:

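```python
import jax
import numpy as np

def get_str_len(x):
    # Runs on the host with a concrete value; np.asarray ensures str() sees a
    # NumPy array even if the callback is handed a jax.Array.
    return np.array(len(str(np.asarray(x))), dtype='int32')

@jax.jit
def fun(key):
    a = jax.random.choice(key, 5)
    out_type = jax.Sha…
```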
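A minimal, self-contained sketch of the full pattern follows. It assumes the callback returns a scalar `int32`, described by `jax.ShapeDtypeStruct((), jnp.int32)`; that shape/dtype and the hypothetical `traced_reprs` list (used only to record what `str` sees during tracing) are illustrative assumptions, not part of the original answer. It contrasts calling `str` on the traced value directly with routing the value through `jax.pure_callback`:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical list, used only to record what str() sees at trace time.
traced_reprs = []

def get_str_len(x):
    # Executed on the host with a concrete value, so str() sees the actual number.
    return np.array(len(str(np.asarray(x))), dtype=np.int32)

@jax.jit
def fun(key):
    a = jax.random.choice(key, 5)  # an integer drawn from [0, 5)
    # Inside jit, `a` is a tracer: str() here describes the abstract value, not a number.
    traced_reprs.append(str(a))
    # Assumed output spec: a scalar int32.
    out_type = jax.ShapeDtypeStruct((), jnp.int32)
    return jax.pure_callback(get_str_len, out_type, a)

result = fun(jax.random.PRNGKey(0))
print(traced_reprs[0])  # e.g. Traced<...>with<...>
print(result)           # 1, since every value in [0, 5) is a single digit
```

The traced repr is whatever JAX prints for an abstract tracer, so its exact text varies between versions; only the callback sees the concrete sampled value.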

Replies: 8 comments 5 replies

Answer selected by gianlucadetommaso
@jakevdp
@jakevdp
@sharadmv
sharadmv Sep 2, 2022
Collaborator
@jakevdp
@jakevdp
Category
Q&A
Labels
None yet
3 participants