Introduce configuration flags for gradient checkpointing #1564
Description
Replaces #1168. I opened a fresh PR because resolving conflicts in the original became messy; this PR carries the same change set and is the clean continuation.
Issue Number
Closes #1141
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`

Performance gain
With all gradient checkpointing disabled on a single node using the default config, we see about a 20% speedup with no CUDA out-of-memory (OOM) errors, compared to the develop branch (commit df59a9a). I ran the code on both the develop branch and this branch with:
`../WeatherGenerator-private/hpc/launch-slurm.py --time 180 --nodes=1`

The experiments have been uploaded:
- Develop branch test ID: `cka04xbw`
- `javad/dev/cond_checkpoint_all-1141-ver1` branch test ID: `s8p7iec0`

If a different configuration runs into CUDA OOM, you can set some (not all) of the `checkpoint_enabled` flags to True. This enables gradient checkpointing, so activations are recomputed during the backward pass instead of being stored, which helps avoid CUDA OOM errors at the cost of extra compute.
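As a rough illustration of the idea behind per-component `checkpoint_enabled` flags (the actual flag names and config structure in this repository may differ, and the block names below are hypothetical), a sketch might look like:

```python
from dataclasses import dataclass

@dataclass
class BlockConfig:
    """Hypothetical per-block config: when checkpoint_enabled is True, the
    block's activations are recomputed in the backward pass (e.g. via
    torch.utils.checkpoint) instead of being kept in memory."""
    name: str
    checkpoint_enabled: bool = False

def checkpointed_blocks(blocks):
    """Return names of blocks whose activations will be recomputed."""
    return [b.name for b in blocks if b.checkpoint_enabled]

# Enable checkpointing only for the memory-heavy block, not all of them,
# trading a little recomputation for lower peak GPU memory.
blocks = [
    BlockConfig("encoder", checkpoint_enabled=True),
    BlockConfig("processor"),
    BlockConfig("decoder"),
]
print(checkpointed_blocks(blocks))  # ['encoder']
```

Turning on only the flags for the largest blocks usually recovers enough memory to avoid OOM while keeping most of the speedup from leaving the rest unchecked.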