Replies: 2 comments
- Alternatively, I guess what I am asking is how to release the memory occupied by the graph used to compute gradients. I find that in JAX, the forward part of …
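As a rough illustration of where that memory actually lives (the `f` and `x` below are made-up toy names, not from the original code): in JAX the forward residuals needed for a backward pass are only kept alive by the closure returned by jax.vjp (or internally by jax.grad while it runs); a plain forward call keeps nothing beyond its output.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((1000, 1000))

# Plain forward call: no backward graph is recorded, only the output stays alive.
y = f(x)

# jax.vjp returns a closure that holds the forward residuals needed for the
# backward pass; that memory is released once vjp_fn is no longer referenced.
y, vjp_fn = jax.vjp(f, x)
(grad_x,) = vjp_fn(jnp.ones_like(y))
del vjp_fn  # residuals become garbage-collectable here
```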
- Hey @zw615, there are 3 options for how to stop gradients in JAX:
For more info, check out Flax's Transfer Learning guide.
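In case it helps later readers, here is a minimal sketch of the jax.lax.stop_gradient option (the loss_fn, params, and shapes below are invented for illustration, not taken from the guide). Gradients for the frozen branch come out as zeros; note that this alone does not skip forward tracing or save memory, which is what the question below is about.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    feats = x @ params["backbone"]        # pretend frozen feature extractor
    feats = jax.lax.stop_gradient(feats)  # cut gradient flow into the backbone
    preds = feats @ params["head"]        # trainable head
    return jnp.mean(preds ** 2)

params = {"backbone": jnp.ones((4, 8)), "head": jnp.ones((8, 1))}
x = jnp.ones((2, 4))

grads = jax.grad(loss_fn)(params, x)
# grads["backbone"] is all zeros, grads["head"] is not.
```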
- Hi there! I want to turn off gradient computation when a model forwards some input, which matters in cases like gradient accumulation when memory is limited. I have searched the issues and discussion panels and found this post: #1937, which talks about jax.lax.stop_gradient. However, I find that the code below only disables gradient flow through the jax.lax.stop_gradient op, but still performs computational graph building/tracing. As a result, the gradient accumulation technique does not save any memory. I wonder how I can extract features without any gradient tracking, just like inference under torch.no_grad? Thanks a lot!
Note that this gradient accumulation implementation is different from the one commonly used in supervised training, e.g. https://github.com/google-research/big_vision/blob/47ac2fd075fcb66cadc0e39bd959c78a6080070d/big_vision/utils.py#L296. This variant is useful in contrastive learning, e.g. CLIP.
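For completeness, here is a simplified sketch of the chunked, contrastive-style accumulation the post describes (the `encode` function, shapes, and loss below are invented for illustration, and treating the cached features of other chunks as constants is a simplification, not the exact gradient-cache algorithm). The key point is that JAX has no global tape: a function called outside jax.grad / jax.vjp records nothing, so the first feature-extraction pass is already the moral equivalent of torch.no_grad, and during the second pass only one chunk's activations are ever live for the backward computation.

```python
import jax
import jax.numpy as jnp

def encode(params, batch):
    # Stand-in for a real encoder.
    return jnp.tanh(batch @ params)

@jax.jit
def extract_features(params, batch):
    # Called outside jax.grad / jax.vjp, so nothing is saved for a backward
    # pass: this is the "no_grad" feature extraction.
    return encode(params, batch)

params = jnp.ones((16, 8))
batches = jnp.ones((4, 32, 16))  # 4 chunks of 32 examples each

# Pass 1: gradient-free feature extraction over all chunks.
all_feats = jnp.concatenate([extract_features(params, b) for b in batches])

def chunk_loss(params, batch, all_feats):
    # Only this chunk's features are recomputed under the gradient; the cached
    # features enter the loss as plain constants.
    feats = encode(params, batch)
    logits = feats @ all_feats.T
    return -jnp.mean(jax.nn.log_softmax(logits, axis=-1))

# Pass 2: accumulate gradients one chunk at a time.
grads = jax.tree_util.tree_map(jnp.zeros_like, params)
for b in batches:
    g = jax.grad(chunk_loss)(params, b, all_feats)
    grads = jax.tree_util.tree_map(jnp.add, grads, g)
```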