Jitted function: output tensor shape depends on input tensor values #8492
Unanswered
refraction-ray asked this question in Q&A
Replies: 1 comment
Yes, you've read the documentation correctly: JAX transforms (including JIT) require arrays to be statically shaped. There is some ongoing experimentation with dynamic shapes, but it is not yet part of the package.
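One pattern that stays within the static-shape requirement is to pad the output to a fixed upper bound and return the number of valid entries alongside it. The reply above does not propose this; it is a minimal sketch, assuming the input length is an acceptable bound. The function name `positive_values_padded` is hypothetical, while the `size` and `fill_value` arguments of `jnp.nonzero` are part of JAX.

```python
import jax
import jax.numpy as jnp


@jax.jit
def positive_values_padded(x):
    """Return the positive entries of x, zero-padded to len(x), and their count."""
    n = x.shape[0]  # static under jit: shapes are known at trace time
    mask = x > 0
    # size= fixes the output length; unused slots are filled with index 0,
    # so they must be masked out explicitly below.
    (idx,) = jnp.nonzero(mask, size=n, fill_value=0)
    count = jnp.sum(mask)
    valid = jnp.arange(n) < count
    values = jnp.where(valid, x[idx], 0.0)
    return values, count


x = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])
values, count = positive_values_padded(x)

# Outside jit the count is a concrete value, so the padding can be sliced off.
trimmed = values[: int(count)]
```

The jitted function is traced once per input shape rather than once per output shape, at the cost of carrying the padding through any later jitted computation.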
-
I know that the "output tensor shape depends on input tensor values" pattern is forbidden in JAX jitted functions (this is mentioned throughout the issues and docs), but I am still curious whether there is a workaround (other than `static_argnums`, since the output shapes can vary a lot, which makes the function retrace on every call). After all, TensorFlow handles this pattern fine even with `jit_compile=True` enabled, i.e. when compiling to XLA. See the demo below. If there is currently no satisfactory workaround in JAX, is there any plan to add such a feature in the future?
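The TensorFlow demo referred to above is not included in this copy of the thread. Separately from it, the restriction on the JAX side can be illustrated with a short sketch. The function name `positive_indices` is hypothetical, and the exception is caught broadly because the exact error class may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def positive_indices(x):
    # The number of positive entries is known only from the values of x,
    # so the output shape is value-dependent.
    return jnp.nonzero(x > 0)[0]


x = jnp.array([1.0, -2.0, 3.0])

# Eager execution works: the values are concrete, so the shape can be computed.
eager = positive_indices(x)

# Under jit, x is an abstract tracer and the output shape cannot be determined.
try:
    jax.jit(positive_indices)(x)
    jit_failed = False
except Exception as err:  # JAX raises a tracer concretization error here
    jit_failed = True
    print(type(err).__name__)
```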