
Try:

In [5]: import jax

In [6]: f = jax.jit(lambda x: x * 2)

In [7]: %timeit f.lower(42).compile()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
7.54 ms ± 231 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

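.lower() is new and not yet documented, but it does exactly what you want: f.lower(42) traces f and lowers it for that argument, and .compile() then compiles the result. The timing therefore covers tracing, lowering and compilation without ever executing f.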
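Outside IPython you can take the same measurement with time.perf_counter, and splitting the two calls shows how much of the cost is tracing/lowering versus XLA compilation. This is a minimal sketch, assuming a JAX version where jax.jit(fn).lower(*args) returns an object with a .compile() method (as in the session above). time_lower_and_compile is a hypothetical helper, not part of JAX. Single measurements are noisy, and depending on the JAX version, in-process caches may make repeated runs faster than the first.

```python
import time

import jax
import jax.numpy as jnp


def time_lower_and_compile(fn, *args):
    """Hypothetical helper: time lowering and compilation of jax.jit(fn) separately."""
    jitted = jax.jit(fn)

    start = time.perf_counter()
    lowered = jitted.lower(*args)  # trace fn and lower it for these arg shapes/dtypes
    lower_s = time.perf_counter() - start

    start = time.perf_counter()
    compiled = lowered.compile()  # compile the lowered module with XLA
    compile_s = time.perf_counter() - start

    return lower_s, compile_s, compiled


x = jnp.arange(4.0)
lower_s, compile_s, compiled = time_lower_and_compile(lambda v: v * 2, x)
print(f"lower: {lower_s * 1e3:.2f} ms, compile: {compile_s * 1e3:.2f} ms")

# The compiled executable can be called with arguments of the same shape and dtype.
print(compiled(x))
```

Because compiled is returned, you can also time execution on its own afterwards without recompiling.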

Answer selected by phc27x