Get composition structure from Jax composed function #11088
-
Perhaps this is more a Python question than a JAX question, but is it possible to recover the composition structure of a function built with JAX transformations? For example, I would like to do something along the lines of this (but also possibly nested):

```python
x = np.ones(2)
# Desired behaviour only: this does not work as written, because jax.vmap
# is a function, not a class, so isinstance() raises a TypeError.
if isinstance(f, jax.vmap):
    y = f(x[None, ...])
else:
    y = f(x)
```

I tried doing this with `inspect`:

```python
import inspect
inspect.getclosurevars(jax.vmap(f))
>> ClosureVars(nonlocals={'fun': <function vmap.<locals>.vmap_f at 0x7fcd940e3af0>}, globals={'filtering_mode': <function filtering_mode at 0x7fcf2028a040>, 'is_under_reraiser': ...
```

It seems like this only reveals the outermost layer. Is there a more general way to achieve this?
Replies: 2 comments 1 reply
-
As far as I am aware, JAX has made no attempt to provide hooks for this level of inspection; rather, JAX makes use of decorators/closures, and liberally uses Python's `functools.wraps` for this purpose. This means that wrapped functions do have a `__wrapped__` attribute that points to the original function:

```python
from jax import vmap, jit

def f(x):
    return x

f_mapped = vmap(f)
f_jitted = jit(f)
f_mapped_then_jitted = jit(f_mapped)

print(f_mapped.__wrapped__ is f)
# True
print(f_jitted.__wrapped__ is f)
# True
print(f_mapped_then_jitted.__wrapped__ is f_mapped)
# True
print(f_mapped_then_jitted.__wrapped__.__wrapped__ is f)
# True
```

That said, JAX does not add any attributes to tell you which transformation has been used on the wrapped function. You may be able to use implementation details to infer this; for example:

```python
>>> type(f)
function
>>> type(jax.vmap(f))
function
>>> type(jax.jit(f))
jaxlib.xla_extension.CompiledFunction
```

But these are not considered public parts of the API, so I would not recommend relying on them. I haven't thought about the pros and cons of implementing a user-facing API to track the transformation structure of functions, but you may wish to open a feature request if you have an application that would benefit from it.
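The `.__wrapped__.__wrapped__` pattern generalizes to any nesting depth. Below is a minimal sketch of a hypothetical helper, `wrapper_chain` (not part of JAX or the standard library), that follows `__wrapped__` links until it reaches a function without one. A toy `passthrough` decorator built on `functools.wraps` stands in for a JAX transformation so the example runs without JAX installed. Applied to `jit(vmap(f))`, the behaviour shown in the answer suggests it would return the jitted function, the vmapped function, and `f`; it still cannot tell you *which* transformation produced each layer.

```python
import functools
from typing import Callable, List


def wrapper_chain(fn: Callable) -> List[Callable]:
    """Follow __wrapped__ links from the outermost wrapper inward.

    Returns [outermost, ..., innermost]. Only layers whose wrapper set
    __wrapped__ (e.g. via functools.wraps) are visible; the chain does not
    record which transformation produced each layer.
    """
    chain = [fn]
    seen = {id(fn)}
    while hasattr(chain[-1], "__wrapped__"):
        nxt = chain[-1].__wrapped__
        if id(nxt) in seen:  # stop on a cyclic __wrapped__ chain
            break
        chain.append(nxt)
        seen.add(id(nxt))
    return chain


# Hypothetical toy decorator standing in for a JAX transformation.
def passthrough(fun: Callable) -> Callable:
    @functools.wraps(fun)  # sets wrapped.__wrapped__ = fun
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)
    return wrapped


def f(x):
    return x


g = passthrough(passthrough(f))
layers = wrapper_chain(g)
print(len(layers))      # 3: outer wrapper, inner wrapper, f
print(layers[-1] is f)  # True
```

The standard library's `inspect.unwrap` follows the same links but returns only the innermost function, so it loses the intermediate layers.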
-
If you are not writing a library, I think you can wrap function transformations (…
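A minimal sketch of that idea, assuming you control every place a transformation is applied: a hypothetical `tagged` helper wraps a transformation so that each function it returns carries a `_transforms` tuple (an attribute name chosen here, not a JAX API), listing the tagged transformations outermost first. Toy `doubled`/`negated` transformations stand in for `jax.vmap`/`jax.jit` so the example runs without JAX; with JAX you would write something like `my_vmap = tagged(jax.vmap, "vmap")`.

```python
import functools
from typing import Callable, Tuple


def tagged(transform: Callable, name: str) -> Callable:
    """Hypothetical helper: wrap a function transformation so every function
    it returns records the tagged transformations applied to it, outermost
    first, in a `_transforms` attribute."""

    @functools.wraps(transform)
    def tagged_transform(fun: Callable, *args, **kwargs) -> Callable:
        # Extra args (e.g. in_axes for vmap) are forwarded unchanged.
        transformed = transform(fun, *args, **kwargs)

        # Thin Python wrapper so an attribute can always be attached,
        # whatever object type the underlying transform returns.
        @functools.wraps(transformed)
        def wrapper(*a, **kw):
            return transformed(*a, **kw)

        inner: Tuple[str, ...] = getattr(fun, "_transforms", ())
        wrapper._transforms = (name,) + inner
        return wrapper

    return tagged_transform


# Hypothetical toy transformations standing in for JAX ones.
def doubled(fun: Callable) -> Callable:
    @functools.wraps(fun)
    def g(*args, **kwargs):
        return 2 * fun(*args, **kwargs)
    return g


def negated(fun: Callable) -> Callable:
    @functools.wraps(fun)
    def g(*args, **kwargs):
        return -fun(*args, **kwargs)
    return g


my_double = tagged(doubled, "double")
my_negate = tagged(negated, "negate")

h = my_negate(my_double(lambda x: x + 1))
print(h(3))           # -8
print(h._transforms)  # ('negate', 'double')
```

The question's `isinstance(f, jax.vmap)` check would then become something like `getattr(f, "_transforms", ())[:1] == ("vmap",)`. The trade-offs: only transformations applied through `tagged` are recorded, and the extra Python wrapper forwards the call and the `__dict__` contents but not methods defined on the transformed object's type.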