Hi all! I have transformations which I apply to functions -- that can occur after […] I haven't figured out a good way to do this, and I have many questions. For example: […]
If I later on attempt to […] Now, I may be doing something horrible (like, should I even be using […]
Replies: 1 comment
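Good question! There isn't really any public API for this. The object returned by `jax.jit` is defined in C++ and does not directly expose the `static_argnames` or `static_argnums` parameters in its Python API. However, these components are part of the object's state used for pickling, so you can access them via `f.__getstate__()`. For example:

```python
from functools import partial

from jax import jit


@partial(jit, static_argnames=['axis'])
def sum(x, axis=None):
    return x.sum(axis)


# Non-public: the pickling state includes the static argument settings.
print(sum.__getstate__())
```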
The returned value is a Python dict with entries related to what you're asking about. (Note that these are versions of […]) Be warned that if you use this, you should consider it a non-public API whose details may change without warning from release to release.
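Since the pickling state is non-public, code that reads it is safer if it degrades gracefully when the state is missing or shaped differently. Here is a minimal sketch of that idea. The helper name `static_arg_info` is hypothetical (it is not part of JAX), and it assumes only that the state, when available, is a dict that may contain `static_argnums` and `static_argnames` keys:

```python
from typing import Any, Optional


def static_arg_info(f: Any) -> Optional[dict]:
    """Best-effort lookup of static argument settings from a jitted function.

    Hypothetical helper: it relies on the non-public pickling state and
    returns None whenever that state is unavailable or has another shape.
    """
    getstate = getattr(f, "__getstate__", None)
    if getstate is None:
        return None
    try:
        state = getstate()
    except Exception:
        # Some objects define __getstate__ but refuse to be pickled.
        return None
    if not isinstance(state, dict):
        return None
    keys = ("static_argnums", "static_argnames")
    # Assumed key names; treat their absence as "no information".
    if not any(key in state for key in keys):
        return None
    return {key: state.get(key) for key in keys}
```

Calling `static_arg_info(f)` on a jitted function should then give either a small dict with those two entries or `None`, so callers can fall back to tracking the static arguments themselves if a future release changes the state layout.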