You can use the jax.disable_jit() context manager:

    import jax

    with jax.disable_jit():
        do_something_without_jit()  # your own code; runs without JIT inside this block

Or, to disable JIT globally, set the config option:

    from jax import config
    config.update('jax_disable_jit', True)

More at https://jax.readthedocs.io/en/latest/jax.html#jax.disable_jit
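A minimal sketch of what the context manager changes, assuming a recent JAX release (the function halve_if_large is hypothetical, not from the thread). A Python if on an array value needs a concrete number, so it works inside jax.disable_jit() but raises once the same function is traced by jax.jit. Note that disabling JIT only stops the decorated function from being traced as a whole; as the question points out, individual jax.numpy operations are still executed through XLA one at a time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def halve_if_large(x):
    total = jnp.sum(x)
    # A Python `if` on an array value needs a concrete number. Under jit, `total`
    # is an abstract tracer and this line raises; with JIT disabled it is a real
    # array, so the branch runs (and `total` can be printed or inspected in pdb).
    if total > 10:
        return total / 2
    return total


x = jnp.arange(6.0)  # 0 + 1 + ... + 5 = 15.0

# Scoped: jit-decorated functions called inside this block run as plain Python.
with jax.disable_jit():
    eager = halve_if_large(x)
print(eager)  # 7.5

# Outside the block the same call is traced for compilation, so the `if` fails.
try:
    halve_if_large(x)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
print(traced_ok)  # False

# Global alternative: affects every later call in the process.
# jax.config.update('jax_disable_jit', True)
```

The scoped form is usually the safer choice while debugging, since the global flag stays on for everything else in the session.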
Python lends itself well to an interactive style of programming:
This approach works well when the code executes quickly and iterations are short. I use it all the time for new designs, to get things right with the first set of inputs.
However, with JAX omnistaging, it seems that most JAX/Flax code is compiled to XLA regardless of whether JIT is disabled, which causes lengthy delays and makes this interactive style painfully slow.
Is there a way to speed up, or avoid entirely, compiling to XLA?
I have heard that, in theory, replacing jax.numpy calls with numpy is a workable solution: most of an algorithm or Flax module can be worked out quickly using numpy and then converted to JAX. In practice, however, I know of no way to do this except tediously, call by call.
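One way to avoid the call-by-call conversion (a common pattern, not something proposed in this thread) is to write the numerical code against an array-module argument, here called xp, and pass either numpy or jax.numpy. This is a minimal sketch with a hypothetical function normalize; it only covers the API the two modules share, so NumPy-only idioms such as in-place assignment (x[i] = v, which JAX arrays do not support) still need rewriting.

```python
import numpy as np


def normalize(x, xp=np):
    """Scale `x` to zero mean and unit variance using the array module `xp`."""
    # Every array call goes through `xp`, so switching backends means changing
    # one argument rather than rewriting each call.
    x = xp.asarray(x)
    return (x - xp.mean(x)) / xp.std(x)


# Iterate quickly with plain NumPy (no XLA compilation involved).
out_np = normalize([1.0, 2.0, 3.0])
print(out_np)

# Once the logic is settled, pass jax.numpy instead (only if JAX is installed).
try:
    import jax.numpy as jnp

    out_jax = normalize([1.0, 2.0, 3.0], xp=jnp)
except ImportError:
    out_jax = None
```

The NumPy and JAX results should agree up to float32 rounding, since jax.numpy defaults to 32-bit floats.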
Appreciate any thoughts and ideas on this.
Best to all