
You can use jax.disable_jit() as a context manager to run all code inside it without JIT compilation, so jitted functions execute eagerly as ordinary Python:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print(x)
  return x

x = jnp.arange(5)

f(x)
# prints something like Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)> (the exact repr depends on the JAX version)

with jax.disable_jit():
  f(x)
  # prints [0 1 2 3 4]
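The jitted call prints a tracer because Python side effects such as print run only while JAX traces the function, not on every call. Below is a minimal sketch of that behavior. The counter trace_count is a hypothetical illustration that was not part of the original answer. It also shows jax.debug.print, which stages the print into the compiled function so runtime values appear without disabling JIT:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented by a plain Python side effect

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # plain Python: executes only while JAX traces f
    jax.debug.print("runtime value: {}", x)  # staged out: executes on every call
    return x * 2

x = jnp.arange(5)
f(x)  # first call: traces and compiles f, so trace_count becomes 1
f(x)  # same shape and dtype: reuses the cached compiled function, no retrace

with jax.disable_jit():
    f(x)  # runs eagerly as ordinary Python, so the counter increments again
```

After these three calls trace_count is 2. The second jitted call hits the compilation cache, so its Python body never runs.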

As for calling non-jitted code within jitted code, it's not so easy: that would require on-device XLA code calling back into the Python runtime. You may be able to do something like this using host_callback (deprecated in recent JAX releases in favor of jax.pure_callback and jax.debug.callback), but I don't know of any example of exactly the scenario you have in mind.
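Here is a minimal sketch of calling ordinary Python from inside a jitted function with jax.pure_callback, a JAX callback API. It assumes the host function is pure (no side effects) and that its output shape and dtype are known before compilation. The function host_sort_desc is a hypothetical example and not the scenario from the question:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sort_desc(x):
    # Hypothetical host-side function: runs as ordinary Python/NumPy and
    # receives a concrete array rather than a tracer.
    return np.sort(np.asarray(x))[::-1].copy()

@jax.jit
def g(x):
    # pure_callback needs the output shape/dtype up front, because XLA
    # must know it at compile time.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sort_desc, out_spec, x)
    return y + 1  # the callback's result flows back into traced JAX code

result = g(jnp.array([3, 1, 2]))
print(result)  # [4 3 2]
```

Because the compiled program must hand data back to the Python interpreter, each call pays a host round trip. The callback is therefore best kept for code that genuinely cannot be expressed in JAX.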

Answer selected by marius311