As far as I am aware, JAX makes no attempt to provide hooks for this level of inspection. Instead, JAX builds its transformations from decorators and closures, and liberally uses Python's functools.wraps. As a result, wrapped functions have a __wrapped__ attribute that points to the original function:

from jax import vmap, jit

def f(x):
  return x

f_mapped = vmap(f)
f_jitted = jit(f)
f_mapped_then_jitted = jit(f_mapped)

print(f_mapped.__wrapped__ is f)
# True
print(f_jitted.__wrapped__ is f)
# True
print(f_mapped_then_jitted.__wrapped__ is f_mapped)
# True
print(f_mapped_then_jitted.__wrapped__.__wrapped__ is f)
# True
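
If you want the whole chain rather than one level at a time, you can walk the __wrapped__ attributes in a loop. Below is a minimal sketch in plain Python. The names wrapper_chain and passthrough are hypothetical helpers, not part of JAX, and passthrough is only a stand-in for any transformation that applies functools.wraps.

```python
import functools


def wrapper_chain(fn):
    """Return [fn, fn.__wrapped__, fn.__wrapped__.__wrapped__, ...]."""
    chain = [fn]
    seen = {id(fn)}
    while hasattr(chain[-1], "__wrapped__"):
        inner = chain[-1].__wrapped__
        if id(inner) in seen:  # guard against a cyclic __wrapped__ chain
            break
        seen.add(id(inner))
        chain.append(inner)
    return chain


def passthrough(fn):
    """Hypothetical stand-in for a transformation built on functools.wraps."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped


def f(x):
    return x


g = passthrough(f)
h = passthrough(g)

chain = wrapper_chain(h)
print([c is expected for c, expected in zip(chain, [h, g, f])])
# [True, True, True]
```

Given the outputs shown above, calling wrapper_chain(f_mapped_then_jitted) should return the jitted function, then f_mapped, then f. If you only need the innermost function, the standard library's inspect.unwrap follows the same attribute.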

That said, JAX does not add any attributes to tell you which transf…

Answer selected by joeryjoery