My model cannot fit in GPU memory after applying a "layer dropout" trick.

Inside a Haiku module, before:

after:

This can be worked around by changing the module to unconditionally compute `layer(layer_type, x)`. But I wonder why JAX needs to use more GPU memory (more than 1.6 times as much) for variant 1?
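The before/after snippets aren't shown, so here is a minimal sketch of the two variants described above, assuming the layer-dropout trick wraps the layer call in `jax.lax.cond`. Only the call `layer(layer_type, x)` comes from the post; the layer body, the shapes, and the `keep` flag are hypothetical:

```python
import jax
import jax.numpy as jnp


def layer(layer_type, x):
    # Hypothetical stand-in for the real Haiku sub-layer; only the call
    # signature `layer(layer_type, x)` comes from the post.
    del layer_type
    return jnp.tanh(x) * 2.0


def block_conditional(x, keep):
    # Variant 1 (the "layer dropout" trick): run the layer only when `keep`
    # is True. Under jit, lax.cond compiles *both* branches, which may be
    # one reason peak GPU memory goes up.
    return jax.lax.cond(keep, lambda y: layer("mlp", y), lambda y: y, x)


def block_unconditional(x, keep):
    # The workaround from the post: always compute the layer, then select.
    return jnp.where(keep, layer("mlp", x), x)


x = jnp.ones((4, 128))
keep = jnp.array(True)
print(jax.jit(block_conditional)(x, keep).shape)    # (4, 128)
print(jax.jit(block_unconditional)(x, keep).shape)  # (4, 128)
```

Note that inside a real Haiku module, `hk.cond` is the usual wrapper for `jax.lax.cond` around code that creates or uses parameters.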
Replies: 1 comment 1 reply

Hey @jjyyxx, I don't know the exact technical details, but I've also experienced increased memory and slower runtime when using …