
Hi! Vanilla reverse-mode autodiff (the jax.grad default) needs to store all intermediate values of the forward pass, which is what causes the OOM.
You can use jax.checkpoint to reduce memory consumption, at the cost of extra computation.
How to apply it to the loop in final_function (not the loop over batches):
(It would also work for the loop over batches, but there is a more efficient way to handle that one.)
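As a minimal illustration of that trade-off: wrapping a function in jax.checkpoint makes the backward pass recompute the function's internals instead of storing them, while the gradient value stays the same. The function layer below is a hypothetical placeholder, not code from the question.

```python
import jax
import jax.numpy as jnp


def layer(x):
    # Hypothetical function with several elementwise intermediates that
    # plain reverse mode would keep in memory for the backward pass.
    return jnp.sin(jnp.cos(x) * 2.0) + x


# With the default policy, jax.checkpoint saves only the inputs of `layer`;
# everything inside it is recomputed during the backward pass.
layer_remat = jax.checkpoint(layer)


def loss_plain(x):
    return jnp.sum(layer(x) ** 2)


def loss_remat(x):
    return jnp.sum(layer_remat(x) ** 2)


x = jnp.arange(4.0)
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)  # same values, less stored state
```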

  1. Divide the N-step scan into sqrt(N) chunks of sqrt(N) steps each.
  2. Wrap each chunk with jax.checkpoint (this can be implemented as a nested scan in which the inner scan is checkpointed).
  3. JAX then stores only the inputs of each chunk, and recomputes the other intermediate values of a chunk when back-propagating through that chunk.
  4. Thus the pe…
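The chunking scheme in steps 1 to 3 can be sketched as a nested jax.lax.scan whose inner scan is wrapped in jax.checkpoint. This is a minimal sketch, assuming the loop is expressed as a lax.scan over N steps with N a perfect square; step, plain_loop, chunked_loop and N are hypothetical names, and the tanh update is only a placeholder for the real loop body.

```python
import math

import jax
import jax.numpy as jnp

N = 16  # hypothetical number of loop steps; assumed to be a perfect square


def step(carry, _):
    # Hypothetical per-step update standing in for one loop iteration.
    return jnp.tanh(1.1 * carry + 0.1), None


def plain_loop(x0):
    # Baseline: reverse mode keeps residuals for all N steps.
    xN, _ = jax.lax.scan(step, x0, None, length=N)
    return jnp.sum(xN)


def chunked_loop(x0):
    chunk = math.isqrt(N)  # sqrt(N) steps per chunk
    assert chunk * chunk == N, "sketch assumes N is a perfect square"

    # Checkpointed inner scan: only the chunk's input carry is saved; the
    # chunk's intermediates are recomputed when back-propagating through it.
    @jax.checkpoint
    def run_chunk(carry, _):
        carry, _ = jax.lax.scan(step, carry, None, length=chunk)
        return carry, None

    # Outer scan over sqrt(N) chunks: stores one carry per chunk.
    xN, _ = jax.lax.scan(run_chunk, x0, None, length=N // chunk)
    return jnp.sum(xN)


x0 = jnp.linspace(-1.0, 1.0, 8)
g_plain = jax.grad(plain_loop)(x0)
g_chunked = jax.grad(chunked_loop)(x0)
```

Both versions compute the same gradient. In the chunked version the stored state is roughly sqrt(N) chunk inputs plus the residuals of one chunk at a time, instead of residuals for all N steps, at the cost of about one extra forward pass. If N is not a perfect square, the leftover steps would need a separate tail scan, which this sketch omits.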


@noveens

@YouJiacheng

@noveens

Answer selected by noveens
Category: Q&A
2 participants