Returning normal function after jitting #19451
-
For example, in TensorFlow, after applying `@tf.function(jit_compile=True)`, I can do this:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def some_fn():
    # Adding some implementation details
    pass

some_fn.python_function()  # Would run the normal function
```

Is there an equivalent in JAX, or do we have to use
Replies: 3 comments
-
Also, in case
-
JIT-compiled functions do store a reference to their original function in the `_fun` attribute:

```python
import jax

@jax.jit
def f(x):
    return x * 2

print(f(1))
# 2
print(f._fun(1))
# 2
```

That said, this is not a public API and I wouldn't recommend relying on it: as a private attribute, it may be removed in a future release without warning. Instead, if you need a reference to the original function, you can keep a reference to it at the source:

```python
import jax

def f(x):
    return x * 2

f_jit = jax.jit(f)

print(f(1))
# 2
print(f_jit(1))
# 2
```
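A minimal sketch of why keeping both names around can be handy: the plain `f` runs its Python body on every call, whereas `f_jit` runs it only while tracing and then reuses the compiled version for inputs with the same shape and dtype. The `calls` counter is a hypothetical side effect added purely to make this visible.

```python
import jax
import jax.numpy as jnp

calls = {"count": 0}  # hypothetical counter, only here to observe Python-level execution

def f(x):
    calls["count"] += 1  # Python side effect: runs on plain calls and during tracing
    return x * 2

f_jit = jax.jit(f)

f_jit(jnp.ones(3))   # first call: traces f (count -> 1), then compiles
f_jit(jnp.zeros(3))  # same shape/dtype: cached executable, f's body is not re-run
print(calls["count"])  # 1

f(jnp.ones(3))       # plain Python call: body runs eagerly (count -> 2)
print(calls["count"])  # 2
```

This is also why the original function is useful for debugging: `print` calls or breakpoints inside `f` fire on every call, while inside `f_jit` they only fire during tracing.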
-
Understandable. However, do you guys plan on moving this feature to the public API (as TensorFlow has)? I feel it might be useful in cases where we can't keep the reference, though I guess that might be a constraint only I am facing.
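For the case where a separate reference can't be kept, one workaround is to attach the original function yourself. The sketch below uses a hypothetical `jit_keep_original` decorator (not part of JAX) that wraps the jitted function in a plain Python function and exposes the original under a `python_function` attribute, a name borrowed from TensorFlow purely for familiarity. The extra Python-level wrapper adds a small per-call overhead.

```python
import functools

import jax


def jit_keep_original(fun):
    """Hypothetical helper (not a JAX API): jit `fun` but keep a public handle to it."""
    jitted = jax.jit(fun)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Forward every call to the compiled version.
        return jitted(*args, **kwargs)

    # Plain Python functions accept arbitrary attributes, so store the
    # un-jitted original under a TensorFlow-style name (our own convention).
    wrapper.python_function = fun
    return wrapper


@jit_keep_original
def f(x):
    return x * 2


print(f(1))                  # runs the jitted version
print(f.python_function(1))  # runs the original Python function
# 2
```

Because the handle lives on an ordinary Python function rather than on JAX's private `_fun`, it does not depend on JAX internals.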