Good question! There isn't a public API for this. The object returned by jax.jit is defined in C++ and does not expose the static_argnames or static_argnums parameters in its Python API. However, these values are part of the object's state used for pickling, so you can access them via f.__getstate__(). For example:

from jax import jit
from functools import partial

@partial(jit, static_argnames=['axis'])
def sum(x, axis=None):
  return x.sum(axis)

print(sum.__getstate__())
# Output (truncated):
{'version': 1,
 'fun': <function __main__.sum(x, axis=None)>,
 'cache_miss': <function jax._src.api._cpp_jit.<locals>.cache_miss(*args, **kwargs)>,
 'get_device': <function jax._src.api._cpp_jit.…
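If you need this lookup in more than one place, you might wrap it in a small helper. This is a minimal sketch, not part of JAX: the name get_static_args is hypothetical, and it assumes the dict returned by __getstate__() contains the keys 'static_argnums' and 'static_argnames' (the output above is truncated before them). Because this is private pickling state, the keys may differ or disappear between JAX versions, so missing keys are treated as "no static arguments".

```python
from typing import Any, Tuple


def get_static_args(f: Any) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
    """Return (static_argnums, static_argnames) for a jitted function.

    Hypothetical helper. It reads the private pickling state returned by
    f.__getstate__() and ASSUMES that state is a dict with the keys
    'static_argnums' and 'static_argnames'. This is not a public JAX API.
    """
    getstate = getattr(f, "__getstate__", None)
    if getstate is None:
        raise TypeError(f"{f!r} has no __getstate__ method")
    state = getstate()
    if not isinstance(state, dict):
        raise TypeError(
            f"expected a dict from __getstate__, got {type(state).__name__}"
        )
    # Normalize lists/tuples (or a missing/None value) to tuples.
    argnums = tuple(state.get("static_argnums") or ())
    argnames = tuple(state.get("static_argnames") or ())
    return argnums, argnames
```

With the sum example above, you would call get_static_args(sum); if the key assumption holds for your JAX version, 'axis' should appear among the returned names. Note that jit may also fill in static_argnums from static_argnames, so check both entries rather than relying on one.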

Answer selected by femtomc