The situation I'm in is that I'm writing some library code that I want JIT-ed, but at some point I call a user-provided function that may not be JIT-able. So ideally I want to write something like:

```python
import jax
import numpy as np

@jax.dontjit
def user_provided_unjittable_function(x):
    return np.exp(x)

@jax.jit
def my_library_function(x):
    # some stuff here I want jitted
    return user_provided_unjittable_function(x)

my_library_function(2)
```

where the user has indicated not to JIT their function with a hypothetical `@jax.dontjit`. I realize you can refactor:

```python
def my_library_function(x):
    jax.jit(_jitted_parts)(...)  # some stuff here I want jitted
    return user_provided_unjittable_function(x)
```

but in my real example that would be much more complex and would make the code less clear.
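For reference, the refactoring described above might look like this in runnable form. This is a minimal sketch: `_jitted_parts` and its body (`jnp.sin(x) * 2.0`) are hypothetical stand-ins for the real jitted work, and the user function is called from ordinary Python after the compiled part has returned a concrete array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def user_provided_unjittable_function(x):
    # Plain NumPy: fine here because it is called outside of jit.
    return np.exp(x)


@jax.jit
def _jitted_parts(x):
    # Hypothetical stand-in for "some stuff here I want jitted".
    return jnp.sin(x) * 2.0


def my_library_function(x):
    y = _jitted_parts(x)  # compiled and run by XLA
    # y is a concrete array here, so NumPy can consume it directly.
    return user_provided_unjittable_function(y)


print(my_library_function(2.0))
```

The cost of this pattern is the one the question points out: every boundary between jitted and non-jitted code has to be split into a separate function by hand.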
You can use `jax.disable_jit()` as a context manager to run non-jitted versions of all code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print(x)
    return x

x = jnp.arange(5)
f(x)
# prints Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>

with jax.disable_jit():
    f(x)
# prints [0 1 2 3 4]
```

As for calling non-jitted code within jitted code, it's not so easy: that would require on-device XLA code calling back to the Python runtime. You may be able to do something like this using `jax.experimental.host_callback`, but I don't know of any example of exactly the scenario you have in mind.
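The answer mentions `jax.experimental.host_callback`; that module has since been deprecated, and `jax.pure_callback` is the suggested way to call back into Python from jitted code in recent JAX releases. Below is a minimal sketch of something close to the `@jax.dontjit` idea from the question. The `dontjit` decorator is hypothetical (it is not a JAX API) and assumes the user function takes a single array and returns an array with the same shape and dtype as its input.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def dontjit(fn):
    """Hypothetical decorator: run `fn` on the host via jax.pure_callback.

    Assumes `fn` takes one array and returns an array with the same
    shape and dtype as its input.
    """
    @functools.wraps(fn)
    def wrapper(x):
        x = jnp.asarray(x)
        # Tell JAX what the host function returns so XLA can allocate the output.
        out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
        # The callback receives concrete arrays, so NumPy code works inside it;
        # the result is cast to the promised dtype to match out_spec.
        host_fn = lambda a: np.asarray(fn(a), dtype=out_spec.dtype)
        return jax.pure_callback(host_fn, out_spec, x)
    return wrapper


@dontjit
def user_provided_unjittable_function(x):
    return np.exp(x)


@jax.jit
def my_library_function(x):
    y = x * 2.0  # stand-in for "some stuff here I want jitted"
    return user_provided_unjittable_function(y)


print(my_library_function(jnp.arange(3.0)))
```

Because `pure_callback` declares the function pure, JAX is free to skip or repeat the call, so side-effecting user code is better served by `jax.experimental.io_callback`. The callback is also not differentiable on its own; supporting `jax.grad` would need something like `jax.custom_jvp` around it.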