Skip to content

Conversation

@javak87
Copy link
Contributor

@javak87 javak87 commented Oct 28, 2025

Description

This is a merge branch related to the following PRs:
#1151
#1152
#1153
#1155
#1156

Issue Number

Closes #1141

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hegdedoc in the github issue with all the configurations and runs for this experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Comparing the baseline and current PR performance

When pred_gradient_checkpoint_mode is set to false, the performance and GPU memory peak are as follows:

default_config.yml:

../WeatherGenerator-private/hpc/launch-slurm.py --time 60
all_checkpoint_off

mixed.yml:

../WeatherGenerator-private/hpc/launch-slurm.py --time 60 --config ./config/mixed.yml

all_checkpoint_off_mixed

@javak87 javak87 marked this pull request as draft October 28, 2025 10:16
@javak87 javak87 changed the title Test turning all grad checkpointing off Test turning all grad checkpointings off Oct 28, 2025
Copy link
Collaborator

@tjhunter tjhunter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@javak87 thanks for the assembly of this PR. I suggest a way to abstract all the checkpoint changes, happy to talk about it if unclear.


# embed provided input data
x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
if self.embed_gradient_checkpoint_mode:
Copy link
Collaborator

@tjhunter tjhunter Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the right place for branching, because you are then forced to copy the same logic across each branch.

Here is what I would suggest: we write a small conditional checkpoint function in one of the utility files (the signature is based on looking at the pytorch checkpoint function):

    def cond_checkpoint(
            enable_checkpoint: bool,
            function,
            *args,
            use_reentrant: Optional[bool] = None,
            context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
            determinism_check: str = _DEFAULT_DETERMINISM_MODE,
            debug: bool = False,
            **kwargs
    ):
        if enable_checkpoint:
            checkpoint(function, ...)
        else:
            function(*args, **kwargs)

and then, the only change required is to convert:

x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))

into:

x = peh(cond_checkpoint(self.embed_gradient_checkpoint_mode, self.embed, x_in.transpose(-2, -1), use_reentrant=False))

How does that sound?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way, we can also build this rule in the linter: do not use checkpoint directly but use cond_checkpoint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m okay with this helper function. However, I should put it in __init__ for every checkpoint, because having these conditional helper functions in forward is degrading performance.

cell_lens_c,
use_reentrant=False,
)
if self.cf.ae_adapter_grdient_checkpoint_mode:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

always assume the parameter may be missing:

self.cf.get("ae_adapter_grdient_checkpoint_mode", True)

@github-project-automation github-project-automation bot moved this to In Progress in WeatherGen-dev Dec 3, 2025
@javak87
Copy link
Contributor Author

javak87 commented Jan 8, 2026

Because of multiple conflicts, I opened a new PR.
New PR #1564

@javak87 javak87 closed this Jan 8, 2026
@github-project-automation github-project-automation bot moved this from In Progress to Done in WeatherGen-dev Jan 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Balance Memory and Compute in Gradient Checkpointing

2 participants