Interaction of jax.checkpoint and jax.jit #6228
Unanswered
davisyoshida asked this question in Q&A
Replies: 1 comment · 4 replies
I've ported GPT-2 to JAX+Haiku, and found that I was able to do Adam steps with a batch size of 1 at a 382-token context size. I was hoping to push this higher by using `jax.checkpoint` on each layer's self-attention block, but surprisingly I was getting OOM errors at the same context size that had run successfully without checkpointing. I'm guessing there's something going on in the interaction of `jit` and `checkpoint`, but I don't have any idea how to investigate this further. Any ideas on how I can prod at this to figure out what's happening?

Apologies for the lack of a minimal example; since I'm not sure what's going on, I couldn't figure out how to reproduce it in a small one.
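For context, a minimal sketch of the pattern being described, with hypothetical layer shapes and count rather than the actual GPT-2 code:

```python
import jax
import jax.numpy as jnp

def layer(w, x):
    # Stand-in for one transformer block's self-attention.
    return jnp.tanh(x @ w)

def model(weights, x):
    for w in weights:
        # jax.checkpoint rematerializes this layer's activations on the
        # backward pass instead of storing them, trading compute for memory.
        x = jax.checkpoint(layer)(w, x)
    return x

@jax.jit
def train_step(weights, x):
    loss_fn = lambda ws: jnp.sum(model(ws, x) ** 2)
    return jax.value_and_grad(loss_fn)(weights)

key = jax.random.PRNGKey(0)
weights = [jax.random.normal(key, (256, 256)) for _ in range(4)]
loss, grads = train_step(weights, jnp.ones((1, 256)))
```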
-
If this is inside an `hk.transform`, see https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html. While I'm here, I'd recommend looking into …
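A minimal sketch of that suggestion, assuming the model is built with `hk.transform` (the `MLP` stand-in and shapes are hypothetical, not from the thread): Haiku provides `hk.remat`, a state-aware wrapper around `jax.checkpoint`/`jax.remat`, for use inside transformed functions.

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    for _ in range(2):  # hypothetical small stack of blocks
        block = hk.nets.MLP([512, 512])  # stand-in for a self-attention block
        # hk.remat is Haiku's analogue of jax.checkpoint; applying
        # jax.checkpoint directly to code that creates Haiku parameters
        # is the situation the linked notebook discusses.
        x = hk.remat(block)(x)
    return x

model = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([1, 512])
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)
```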