make_jaxpr doesn't work with jitted function having static_argnums #15795
I am learning JAX and was interested to see how JAX handles static arguments differently from normal arguments, but the following code doesn't seem to work. Am I missing something obvious or is this a bug?
Error:
Hi, thanks for the question! Unlike `jit`, `make_jaxpr` does not have any mechanism to support static arguments directly, so the best way to proceed is to close over the static arguments, for example using `partial`:

```python
from functools import partial
from jax import make_jaxpr
# Close over the static argument y so that only x is traced.
make_jaxpr(partial(f, y=y))(x)
```
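Below is a minimal, self-contained sketch of this approach. The code from the original question isn't shown, so `f`, `x` and `y` here are hypothetical. `f` takes an array `x` and a plain Python int `y` that controls a loop, which is the kind of argument that has to be static. Because `y` is bound before tracing, it never becomes a tracer, and its value is baked into the resulting jaxpr as a fixed number of `sin` operations.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the function in the question (its code is not shown):
# `y` is a plain Python int that drives Python control flow, so it must be static.
def f(x, y):
    for _ in range(y):
        x = jnp.sin(x)
    return x


# Under jit, y is marked static by name; x is traced as usual.
f_jit = jax.jit(f, static_argnames="y")

x = jnp.ones(3)
y = 2

# By default make_jaxpr traces every argument it is called with, so
# make_jaxpr(f)(x, y) would turn y into an abstract tracer and range(y) would fail.
# Binding y with functools.partial keeps it a concrete Python int while tracing.
jaxpr = jax.make_jaxpr(partial(f, y=y))(x)
print(jaxpr)  # only x is an input; the loop is unrolled into y `sin` ops

# The same closure trick works on the jitted function.
jaxpr_jit = jax.make_jaxpr(partial(f_jit, y=y))(x)
print(jaxpr_jit)
```

Tracing again with a different value, for example `partial(f, y=3)`, produces a different jaxpr. This mirrors how `jit` re-traces when a static argument changes.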