Test turning all grad checkpointings off #1168
Conversation
tjhunter left a comment
@javak87 thanks for assembling this PR. I suggest a way to abstract all the checkpoint changes; happy to talk about it if anything is unclear.
# embed provided input data
x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
if self.embed_gradient_checkpoint_mode:
This is not the right place for branching, because you are then forced to copy the same logic across each branch.
Here is what I would suggest: we write a small conditional checkpoint function in one of the utility files (the signature is based on the PyTorch checkpoint function):
from typing import Callable, ContextManager, Optional, Tuple

# noop_context_fn and _DEFAULT_DETERMINISM_MODE come from torch.utils.checkpoint
# (recent PyTorch versions), matching the signature of torch's checkpoint.
from torch.utils.checkpoint import _DEFAULT_DETERMINISM_MODE, checkpoint, noop_context_fn


def cond_checkpoint(
    enable_checkpoint: bool,
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs,
):
    # Run `function` under activation checkpointing only when requested;
    # otherwise call it directly.
    if enable_checkpoint:
        return checkpoint(
            function, *args,
            use_reentrant=use_reentrant, context_fn=context_fn,
            determinism_check=determinism_check, debug=debug, **kwargs,
        )
    return function(*args, **kwargs)

and then, the only change required is to convert:

x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))

into:

x = peh(cond_checkpoint(self.embed_gradient_checkpoint_mode, self.embed, x_in.transpose(-2, -1), use_reentrant=False))

How does that sound?
This way, we can also build this rule into the linter: do not use checkpoint directly, use cond_checkpoint instead.
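As a rough sketch of what such a lint rule could look like (the source layout, file names, and integration into ./scripts/actions.sh lint are assumptions for illustration, not part of this PR), a small AST-based check in Python:

import ast
import pathlib
import sys

# Assumed name of the module that defines cond_checkpoint; adjust as needed.
ALLOWED_FILES = {"cond_checkpoint.py"}


def direct_checkpoint_calls(path: pathlib.Path) -> list[int]:
    """Return line numbers of direct calls to `checkpoint` in a source file."""
    tree = ast.parse(path.read_text(), filename=str(path))
    hits = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            func = node.func
            name = func.id if isinstance(func, ast.Name) else getattr(func, "attr", "")
            if name == "checkpoint":
                hits.append(node.lineno)
    return hits


def main() -> int:
    errors = []
    for path in pathlib.Path("src").rglob("*.py"):  # assumed source root
        if path.name in ALLOWED_FILES:
            continue
        for lineno in direct_checkpoint_calls(path):
            errors.append(f"{path}:{lineno}: call cond_checkpoint instead of checkpoint")
    print("\n".join(errors))
    return 1 if errors else 0


if __name__ == "__main__":
    sys.exit(main())

Wired into the lint step, any new direct checkpoint call would then fail CI.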
I'm okay with this helper function. However, I would put it in __init__ for every checkpoint, because calling these conditional helpers in forward degrades performance.
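A minimal sketch of that idea (module name, placeholder layer, and config key are hypothetical; the positional-encoding wrapper is omitted): resolve the flag once in __init__, e.g. with functools.partial, so forward contains no branch at all:

import functools

import torch
from torch.utils.checkpoint import checkpoint


class EmbedBlock(torch.nn.Module):
    """Simplified stand-in for the real embedding module."""

    def __init__(self, cf):
        super().__init__()
        self.embed = torch.nn.Linear(128, 256)  # placeholder embed layer
        # Decide once at construction time whether to checkpoint; tolerate a
        # missing config entry by defaulting to True.
        if cf.get("embed_gradient_checkpoint_mode", True):
            self._embed_fn = functools.partial(checkpoint, self.embed, use_reentrant=False)
        else:
            self._embed_fn = self.embed

    def forward(self, x_in):
        # No conditional here; the checkpointing choice was made in __init__.
        return self._embed_fn(x_in.transpose(-2, -1))

The same pattern extends to the other checkpointed calls touched by this PR.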
    cell_lens_c,
    use_reentrant=False,
)
if self.cf.ae_adapter_grdient_checkpoint_mode:
Always assume the parameter may be missing:
self.cf.get("ae_adapter_grdient_checkpoint_mode", True)
Because of multiple conflicts, I opened a new PR.
Description
This is a merge branch related to the following PRs:
#1151
#1152
#1153
#1155
#1156
Issue Number
Closes #1141
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60

Comparing the baseline and current PR performance
When pred_gradient_checkpoint_mode is set to false, the performance and GPU memory peak are as follows:

default_config.yml:
../WeatherGenerator-private/hpc/launch-slurm.py --time 60

mixed.yml:
../WeatherGenerator-private/hpc/launch-slurm.py --time 60 --config ./config/mixed.yml