
Commit e18824c (1 parent: f7d5ae4)

fix: ptl 2.6.0 explicitly pass weights_only=False (#710)
## Description

This change in PTL 2.6.0 (Lightning-AI/pytorch-lightning#21072) means we have to explicitly pass `weights_only=False` when calling `BaseGraphModule.load_from_checkpoint` (nice spot, Ana!).

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines, please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md).
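For context, here is a minimal standalone sketch of the behaviour this commit works around (illustrative only, not Anemoi code; the `Metadata` class and file name are made up). Under the `weights_only=True` default that PTL 2.6.0 now follows, `torch.load` refuses to unpickle arbitrary Python objects, which these training checkpoints contain:

```python
# Minimal sketch of the weights_only behaviour change; illustrative only.
import torch


class Metadata:
    """Stands in for the non-tensor objects stored in real training checkpoints."""

    uuid = "0000"


ckpt = {"state_dict": {"w": torch.zeros(2)}, "metadata": Metadata()}
torch.save(ckpt, "demo.ckpt")

try:
    # The restricted unpickler used by weights_only=True rejects
    # arbitrary classes such as Metadata.
    torch.load("demo.ckpt", weights_only=True)
except Exception as err:  # typically an UnpicklingError
    print(f"weights-only load failed: {err}")

# Trusted, self-produced checkpoints can opt back out:
full = torch.load("demo.ckpt", weights_only=False)
print(full["metadata"].uuid)  # "0000"
```

Passing `weights_only=False` through `load_from_checkpoint`, as the diffs below do, restores full unpickling for these trusted checkpoints.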

File tree: 3 files changed, +9 -3 lines

training/src/anemoi/training/train/train.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -208,7 +208,12 @@ def model(self) -> pl.LightningModule:
         # pop data_indices so that the data indices on the checkpoint do not get overwritten
         # by the data indices from the new config
         kwargs.pop("data_indices")
-        model = model_task.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+        model = model_task.load_from_checkpoint(
+            self.last_checkpoint,
+            **kwargs,
+            strict=False,
+            weights_only=False,
+        )
 
         model.data_indices = self.data_indices
         # check data indices in original checkpoint and current data indices are the same
@@ -436,6 +441,7 @@ def _check_dry_run(self) -> None:
         LOGGER.info("Dry run: %s", self.dry_run)
 
     def prepare_compilation(self) -> None:
+
         if hasattr(self.config.model, "compile"):
             self.model = mark_for_compilation(self.model, self.config.model_dump(by_alias=True).model.compile)
         if hasattr(self.config.training, "recompile_limit"):
```
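Note: `weights_only` here is presumably forwarded by Lightning to the underlying `torch.load` call. Since `self.last_checkpoint` is a trusted checkpoint produced by this same training pipeline, opting out of the weights-only restriction is safe.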

training/src/anemoi/training/utils/checkpoint.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -44,7 +44,7 @@ def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Mod
     pytorch model, metadata
 
     """
-    module = BaseGraphModule.load_from_checkpoint(lightning_checkpoint_path)
+    module = BaseGraphModule.load_from_checkpoint(lightning_checkpoint_path, weights_only=False)
     model = module.model
 
     metadata = dict(**model.metadata)
```
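For illustration, a hypothetical call site for the updated helper (the checkpoint path is a placeholder), consistent with the `pytorch model, metadata` return shape shown in the docstring:

```python
# Hypothetical usage; the checkpoint path below is a placeholder.
from anemoi.training.utils.checkpoint import load_and_prepare_model

model, metadata = load_and_prepare_model("checkpoints/last.ckpt")
print(type(model).__name__, sorted(metadata))
```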

training/tests/unit/diagnostics/test_checkpoint.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -134,7 +134,7 @@ def test_same_uuid(tmp_path: str, callback: AnemoiCheckpoint, model: DummyModule
     if Path(tmp_path + "/" + pl_ckpt_name).exists():
         uuid = load_metadata(ckpt_path)["uuid"]
 
-        pl_model = DummyModule.load_from_checkpoint(tmp_path + "/" + pl_ckpt_name)
+        pl_model = DummyModule.load_from_checkpoint(tmp_path + "/" + pl_ckpt_name, weights_only=False)
 
         assert uuid == pl_model.hparams["metadata"]["uuid"]
 
```
