Question on slow JAX JIT compilation #11869
dhruvsreenivas asked this question in Q&A
Replies: 1 comment 12 replies
-
Thanks for updating the code snippets. From the error, it looks like the issue is you are tracing the
then call this function on your inputs.
-
Hi,
Hope you are doing well! I am adapting a recent DeepMind algorithm (posted to arXiv in late June, so official code is not out yet) to a new setting. I have implemented the loss function, but for some reason JAX takes a very long time to JIT compile it.
Attached is a code snippet that illustrates the loss function I want to initialize. Basically, I loop over a sliding window of a sequence of inputs and pad with 0s where needed (see the comments in the code for what I am trying to do). I'm not sure what I am doing wrong that keeps the code from being jittable.
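Compile time under jax.jit tends to grow with the size of the traced program, and a Python for loop over sequence positions is unrolled at trace time, one copy of the body per step. As a sketch only (the original snippet is not shown here), the hypothetical sliding_windows below builds every window with a single zero pad and a single gather, assuming the zero padding goes at the end of the sequence; sliding_windows_loop is a loop-style equivalent kept for comparison.

```python
import jax
import jax.numpy as jnp


def sliding_windows(x, window):
    """Hypothetical helper: all length-`window` windows of x along axis 0.

    Assumes zero padding at the end. x has shape (T, ...); the result has
    shape (T, window, ...), where row t holds x[t], ..., x[t + window - 1]
    and positions past the end of the sequence are 0.
    """
    T = x.shape[0]
    # Append window - 1 zeros so the last windows still have full length.
    pad_width = [(0, window - 1)] + [(0, 0)] * (x.ndim - 1)
    padded = jnp.pad(x, pad_width)
    # (T, window) index matrix: one gather instead of T separately traced slices.
    idx = jnp.arange(T)[:, None] + jnp.arange(window)[None, :]
    return padded[idx]


def sliding_windows_loop(x, window):
    """Same result via a Python loop, which jax.jit unrolls into T slices."""
    T = x.shape[0]
    pad_width = [(0, window - 1)] + [(0, 0)] * (x.ndim - 1)
    padded = jnp.pad(x, pad_width)
    return jnp.stack([padded[t:t + window] for t in range(T)])


# window determines output shapes, so it has to be a static argument under jit.
sliding_windows_jit = jax.jit(sliding_windows, static_argnums=1)

x = jnp.arange(5.0)
# rows: [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 0], [4, 0, 0]
print(sliding_windows_jit(x, 3))
```

jax.vmap over start indices with jax.lax.dynamic_slice_in_dim gives the same result; if the per-window computation is itself heavy, jax.lax.scan likewise keeps the traced program size independent of the sequence length.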
The helper methods called in the snippet above are in the snippet below. On my machine, the sliding_window function is jitted with no functools.partial on any argument, and the pad_sliding_windows method is jitted through functools.partial with seq_len as a static argument (although I'm not sure this is correct, since static arguments must be hashable). The specific error message is below. My jax/jaxlib version is 0.3.15 and my haiku version is 0.0.7.
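On the static seq_len question: Python ints are hashable, so they can be passed as static arguments to jax.jit. The cost is that every distinct static value triggers a fresh trace and compile. A minimal sketch, where pad_to_length is a hypothetical stand-in for pad_sliding_windows (assuming it right-pads with zeros) and TRACE_COUNT is a hypothetical counter that only makes retracing visible:

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only while JAX traces the function.
TRACE_COUNT = {"n": 0}


@functools.partial(jax.jit, static_argnames="seq_len")
def pad_to_length(x, seq_len):
    """Hypothetical stand-in for pad_sliding_windows: zero-pad axis 0 of x to seq_len.

    Assumes seq_len >= x.shape[0].
    """
    TRACE_COUNT["n"] += 1  # Python side effect: skipped on cached calls
    pad = seq_len - x.shape[0]  # both are concrete Python ints at trace time
    return jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))


a = pad_to_length(jnp.ones(3), seq_len=5)  # first call: trace and compile
b = pad_to_length(jnp.ones(3), seq_len=5)  # same static value and input shape: cached
c = pad_to_length(jnp.ones(3), seq_len=6)  # new static value: trace and compile again
print(TRACE_COUNT["n"])  # 2
```

If seq_len takes many distinct values, each one pays a full compile, which can itself look like slow JIT compilation; padding inputs to a few fixed lengths keeps the number of compiles small.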
Any help would be greatly appreciated. I know this might be a dumb question with an easy answer, but I'd love any help and explanation. I'm also happy to explain what I am trying to do in more detail if needed. Thank you so much, I love the framework!