JIT-compiled functions do store a reference to their original function in the _fun attribute:

import jax

@jax.jit
def f(x):
  return x * 2

print(f(1))
# 2
print(f._fun(1))
# 2

That said, _fun is not part of the public API, and I wouldn't recommend relying on it: as a private attribute, it may be changed or removed in a future release without warning.

Instead, if you need a reference to the original function, you can keep a reference to it at the source:

def f(x):
  return x * 2

f_jit = jax.jit(f)
print(f(1))
# 2
print(f_jit(1))
# 2
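
If you'd rather not track two separate names by hand, one option is to bundle them together. The following is only a minimal sketch of that idea: JitPair and jit_pair are hypothetical names made up for illustration, not part of JAX, and the only JAX API used is the public jax.jit.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class JitPair(NamedTuple):
    """Hypothetical container (not a JAX type) for a function and its jitted form."""
    original: Callable
    compiled: Callable


def jit_pair(fun: Callable, **jit_kwargs) -> JitPair:
    """Hypothetical helper: jit `fun` and keep a reference to the original."""
    # Extra keyword arguments are forwarded unchanged to jax.jit.
    return JitPair(original=fun, compiled=jax.jit(fun, **jit_kwargs))


def double(x):
    return x * 2


pair = jit_pair(double)

x = jnp.arange(3)
print(pair.compiled(x))  # calls the jitted version
print(pair.original(x))  # calls the plain Python function
```

Since the original function is stored explicitly, nothing here depends on private attributes of the jitted object, so it should keep working even if those internals change.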

Answer selected by Impure-King