Commit 42a9917

Fix ModelParallel single-file checkpoint with compiled modules

1 parent d4e476f

File tree

2 files changed: +18 −2 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
+### Fixed
+
+- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+---
+
 ## [2.6.0] - 2025-11-21
 
 ### Added
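To illustrate the failure mode the changelog entry describes: ``torch.compile`` returns a wrapper module whose parameter names carry an ``_orig_mod.`` prefix, so an optimizer state dict keyed by the compiled module's names no longer matches the original module's parameter names. A minimal torch-free sketch (the layer names and state contents are illustrative, not taken from the commit):

```python
def lookup(optimizer_state, names):
    """Fetch per-parameter optimizer states by name; report the first missing key."""
    try:
        return [optimizer_state[name] for name in names]
    except KeyError as exc:
        return f"missing: {exc.args[0]}"

# Parameter names as seen through the torch.compile wrapper vs. the original module.
compiled_names = ["_orig_mod.layer.weight", "_orig_mod.layer.bias"]
original_names = ["layer.weight", "layer.bias"]

# State keyed by the compiled (prefixed) names, as after get_optimizer_state_dict.
state = {name: {"exp_avg": 0.0} for name in compiled_names}

# Looking the states up by the original, unprefixed names fails.
print(lookup(state, original_names))  # missing: layer.weight
```

This is the mismatch that surfaced as a ``KeyError`` when the strategy saved a single-file checkpoint.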

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 12 additions & 2 deletions

@@ -286,8 +286,18 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:
 
         state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
         if not self._save_distributed_checkpoint and self.global_rank == 0:
-            # Store the optimizer state dict in standard format
-            state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
+            # ``torch.compile`` wraps the module, so state dict keys are prefixed with ``_orig_mod.``.
+            # Rekey on the wrapped module first, then rekey again on the original module so parameter
+            # names match what the Trainer expects when saving a single-file checkpoint.
+            compiled_model = self.model
+            original_model = getattr(compiled_model, "_orig_mod", None)
+            if original_model is not None:
+                state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, compiled_model)
+                state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_NAME, compiled_model)
+                state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_NAME, original_model)
+                state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, original_model)
+            else:
+                state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, compiled_model)
         return state_dict
 
     @override
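The core idea of the patch is a rekeying round trip: map the name-keyed state to integer parameter ids using the compiled module's naming, then back to names using the original module's naming, which drops the ``_orig_mod.`` prefix. A hedged sketch with plain dicts standing in for ``FSDP.rekey_optim_state_dict`` (the helper functions and layer names here are illustrative, not the library's API):

```python
def rekey_to_ids(state, name_to_id):
    """Re-key a name-keyed state dict to integer parameter ids."""
    return {name_to_id[name]: value for name, value in state.items()}

def rekey_to_names(state, id_to_name):
    """Re-key an id-keyed state dict back to parameter names."""
    return {id_to_name[pid]: value for pid, value in state.items()}

# The compiled wrapper and the original module name the same parameters
# differently, but the underlying parameters (and hence ids) are shared.
compiled_name_to_id = {"_orig_mod.layer.weight": 0, "_orig_mod.layer.bias": 1}
original_id_to_name = {0: "layer.weight", 1: "layer.bias"}

state = {"_orig_mod.layer.weight": {"step": 1}, "_orig_mod.layer.bias": {"step": 1}}
state = rekey_to_ids(state, compiled_name_to_id)    # keys become shared ids
state = rekey_to_names(state, original_id_to_name)  # keys become unprefixed names
print(sorted(state))  # ['layer.bias', 'layer.weight']
```

The fallback branch in the patch keeps the pre-existing single rekey for the uncompiled case, where ``_orig_mod`` is absent and the names already match.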
