
Commit 646e01b

Fix ModelParallel single-file checkpoint with compiled modules
1 parent d4e476f commit 646e01b

File tree

src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/strategies/model_parallel.py
tests/tests_pytorch/strategies/test_model_parallel.py

3 files changed: +134 -1 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
+### Fixed
+
+- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+---
+
 ## [2.6.0] - 2025-11-21
 
 ### Added
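As a point of reference (not part of the commit), the mismatch this entry describes can be reproduced with a few lines of plain PyTorch: compiling a submodule re-registers its parameters under an ``_orig_mod`` prefix in ``named_parameters()``, which is the mapping ``FSDP.rekey_optim_state_dict`` relies on during single-file saves. The module name ``Demo`` below is made up for illustration.

import torch
from torch import nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 4))


m = Demo()
print(sorted(name for name, _ in m.named_parameters()))
# ['model.0.bias', 'model.0.weight']

m.model = torch.compile(m.model)  # wraps the Sequential in an OptimizedModule
print(sorted(name for name, _ in m.named_parameters()))
# ['model._orig_mod.0.bias', 'model._orig_mod.0.weight']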

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 53 additions & 1 deletion
@@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:
 
         state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
         if not self._save_distributed_checkpoint and self.global_rank == 0:
-            # Store the optimizer state dict in standard format
+            state_dict = _align_compiled_param_names_with_module(state_dict, self.model)
             state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
         return state_dict
 
@@ -366,3 +366,55 @@ def set_world_ranks(self) -> None:
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
         rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
+
+
+def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]:
+    """Align optimizer state dict keys with a module that may have compiled submodules.
+
+    When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``.
+    For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer
+    state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod``
+    prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from
+    ``module.named_parameters()``.
+
+    This function inserts ``._orig_mod`` into the state dict keys where necessary so that
+    they match the module's ``named_parameters()`` output.
+
+    """
+    from torch._dynamo import OptimizedModule
+
+    # Build set of compiled submodule prefixes (e.g., "model" if model is compiled)
+    compiled_prefixes: list[str] = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, OptimizedModule):
+            compiled_prefixes.append(name)
+
+    if not compiled_prefixes:
+        return state_dict
+
+    # Sort by length descending so longer prefixes are matched first
+    compiled_prefixes.sort(key=len, reverse=True)
+
+    def _transform_key(key: str) -> str:
+        for prefix in compiled_prefixes:
+            # Check if key starts with "prefix." (the compiled module path)
+            if key == prefix or key.startswith(prefix + "."):
+                suffix = key[len(prefix) :]  # e.g., ".0.weight" or ""
+                # Insert _orig_mod between prefix and rest
+                return f"{prefix}._orig_mod{suffix}"
+        return key
+
+    # Transform keys in "state" section of the optimizer state dict
+    if "state" in state_dict:
+        new_state = {_transform_key(k): v for k, v in state_dict["state"].items()}
+        state_dict = {**state_dict, "state": new_state}
+
+    # Transform param names in "param_groups" section
+    if "param_groups" in state_dict:
+        new_param_groups = []
+        for group in state_dict["param_groups"]:
+            new_group = {**group, "params": [_transform_key(p) for p in group["params"]]}
+            new_param_groups.append(new_group)
+        state_dict = {**state_dict, "param_groups": new_param_groups}
+
+    return state_dict
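As an illustrative aside (not part of the commit), the new helper can also be exercised on its own; ``Net`` and the key names below are made up for the sketch. Only keys that fall under a compiled submodule receive the ``._orig_mod`` insertion, while other keys pass through unchanged:

import torch
from torch import nn

from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)


net = Net()
net.backbone = torch.compile(net.backbone)  # only the backbone is compiled

state_dict = {"state": {"backbone.weight": {"step": 1}, "head.weight": {"step": 1}}}
aligned = _align_compiled_param_names_with_module(state_dict, net)
print(sorted(aligned["state"]))
# ['backbone._orig_mod.weight', 'head.weight']

The longest-prefix-first ordering in the helper means that when several compiled prefixes could match a key, the longest matching prefix is the one that receives the ``._orig_mod`` insertion.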

tests/tests_pytorch/strategies/test_model_parallel.py

Lines changed: 75 additions & 0 deletions
@@ -251,3 +251,78 @@ def configure_model(self) -> None:
     strategy.setup(Mock())
     assert all(not p.is_meta for p in model.parameters())
     assert all(not b.is_meta for b in model.buffers())
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_with_module():
+    """Test that optimizer state dict keys are aligned with compiled submodule parameter names."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    # Test with compiled submodule
+    m = SimpleModule()
+    m.model = torch.compile(m.model)
+
+    # Simulate optimizer state dict without _orig_mod in keys (includes both state and param_groups)
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+            "model.2.weight": {"step": 1},
+            "model.2.bias": {"step": 1},
+        },
+        "param_groups": [{"params": ["model.0.weight", "model.0.bias", "model.2.weight", "model.2.bias"], "lr": 0.01}],
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Verify state keys now have _orig_mod inserted
+    expected_keys = {
+        "model._orig_mod.0.weight",
+        "model._orig_mod.0.bias",
+        "model._orig_mod.2.weight",
+        "model._orig_mod.2.bias",
+    }
+    assert set(result["state"].keys()) == expected_keys
+
+    # Verify param_groups params also have _orig_mod inserted
+    assert set(result["param_groups"][0]["params"]) == expected_keys
+
+    # Verify they match the module's named_parameters
+    param_names = {name for name, _ in m.named_parameters()}
+    assert set(result["state"].keys()) == param_names
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_no_compile():
+    """Test that non-compiled modules pass through unchanged."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    m = SimpleModule()  # Not compiled
+
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+        }
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Keys should be unchanged
+    assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"}
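The two new tests exercise the helper directly; assuming a checkout of the repository and PyTorch >= 2.4 (per the ``RunIf`` guard), they can be run in isolation with ``pytest tests/tests_pytorch/strategies/test_model_parallel.py -k align_compiled``.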
