Commit 31b0976

Fix ModelParallel single-file checkpoint with compiled modules
1 parent d4e476f

File tree: 3 files changed, +122 -1 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
+### Fixed
+
+- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+---
+
 ## [2.6.0] - 2025-11-21
 
 ### Added

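For context on the changelog entry above: the failure comes from a renaming detail of ``torch.compile``. Compiling a submodule in place wraps it in an ``OptimizedModule``, and the wrapped parameters pick up an ``_orig_mod`` segment in their fully qualified names. The snippet below is a minimal sketch, not part of the commit; ``ToyModule`` is a hypothetical stand-in for a LightningModule's network.

import torch
from torch import nn


class ToyModule(nn.Module):  # hypothetical stand-in, not from the Lightning code base
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


m = ToyModule()
print([name for name, _ in m.named_parameters()])
# ['model.0.weight', 'model.0.bias', 'model.1.weight', 'model.1.bias']

m.model = torch.compile(m.model)  # the submodule is now an OptimizedModule wrapper
print([name for name, _ in m.named_parameters()])
# ['model._orig_mod.0.weight', 'model._orig_mod.0.bias', 'model._orig_mod.1.weight', 'model._orig_mod.1.bias']
# An optimizer state dict keyed by the pre-compile names no longer lines up with these,
# which is what surfaced as a KeyError when saving a single-file checkpoint.
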
src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 45 additions & 1 deletion
@@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:
 
         state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
         if not self._save_distributed_checkpoint and self.global_rank == 0:
-            # Store the optimizer state dict in standard format
+            state_dict = _align_compiled_param_names_with_module(state_dict, self.model)
             state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
         return state_dict
 
@@ -366,3 +366,47 @@ def set_world_ranks(self) -> None:
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
         rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
+
+
+def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]:
+    """Align optimizer state dict keys with a module that may have compiled submodules.
+
+    When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``.
+    For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer
+    state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod``
+    prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from
+    ``module.named_parameters()``.
+
+    This function inserts ``._orig_mod`` into the state dict keys where necessary so that
+    they match the module's ``named_parameters()`` output.
+
+    """
+    from torch._dynamo import OptimizedModule
+
+    # Build set of compiled submodule prefixes (e.g., "model" if model is compiled)
+    compiled_prefixes: list[str] = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, OptimizedModule):
+            compiled_prefixes.append(name)
+
+    if not compiled_prefixes:
+        return state_dict
+
+    # Sort by length descending so longer prefixes are matched first
+    compiled_prefixes.sort(key=len, reverse=True)
+
+    def _transform_key(key: str) -> str:
+        for prefix in compiled_prefixes:
+            # Check if key starts with "prefix." (the compiled module path)
+            if key == prefix or key.startswith(prefix + "."):
+                suffix = key[len(prefix) :]  # e.g., ".0.weight" or ""
+                # Insert _orig_mod between prefix and rest
+                return f"{prefix}._orig_mod{suffix}"
+        return key
+
+    # Transform keys in "state" section of the optimizer state dict
+    if "state" in state_dict:
+        new_state = {_transform_key(k): v for k, v in state_dict["state"].items()}
+        state_dict = {**state_dict, "state": new_state}
+
+    return state_dict

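To see where the ``KeyError`` originated, it helps to picture the lookup that rekeying to ``OptimStateKeyType.PARAM_ID`` has to perform. The sketch below is illustrative only and is not the FSDP implementation; ``rekey_by_param_id_sketch`` is a hypothetical helper that models a name-to-id mapping built from ``module.named_parameters()``, which is the lookup that breaks when state dict keys are missing the ``_orig_mod`` segment.

from typing import Any

import torch


def rekey_by_param_id_sketch(optim_state: dict[str, Any], module: torch.nn.Module) -> dict[int, Any]:
    """Toy model of rekeying by parameter id; not the real FSDP.rekey_optim_state_dict."""
    # Fully qualified parameter names mapped to positional ids
    name_to_id = {name: idx for idx, (name, _) in enumerate(module.named_parameters())}
    # A key such as "model.0.weight" that is absent from named_parameters() (because the
    # module actually exposes "model._orig_mod.0.weight") raises KeyError right here,
    # which is the failure the new _align_compiled_param_names_with_module helper prevents.
    return {name_to_id[name]: state for name, state in optim_state["state"].items()}

With the helper applied first, every key in ``state`` matches a ``named_parameters()`` entry, so the lookup succeeds and rank 0 can write the single-file checkpoint. The tests below exercise exactly this alignment.
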
tests/tests_pytorch/strategies/test_model_parallel.py

Lines changed: 71 additions & 0 deletions
@@ -251,3 +251,74 @@ def configure_model(self) -> None:
     strategy.setup(Mock())
     assert all(not p.is_meta for p in model.parameters())
     assert all(not b.is_meta for b in model.buffers())
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_with_module():
+    """Test that optimizer state dict keys are aligned with compiled submodule parameter names."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    # Test with compiled submodule
+    m = SimpleModule()
+    m.model = torch.compile(m.model)
+
+    # Simulate optimizer state dict without _orig_mod in keys
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+            "model.2.weight": {"step": 1},
+            "model.2.bias": {"step": 1},
+        }
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Verify keys now have _orig_mod inserted
+    expected_keys = {
+        "model._orig_mod.0.weight",
+        "model._orig_mod.0.bias",
+        "model._orig_mod.2.weight",
+        "model._orig_mod.2.bias",
+    }
+    assert set(result["state"].keys()) == expected_keys
+
+    # Verify they match the module's named_parameters
+    param_names = {name for name, _ in m.named_parameters()}
+    assert set(result["state"].keys()) == param_names
+
+
+@RunIf(min_torch="2.4")
+def test_align_compiled_param_names_no_compile():
+    """Test that non-compiled modules pass through unchanged."""
+    from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module
+
+    class SimpleModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))
+
+        def forward(self, x):
+            return self.model(x)
+
+    m = SimpleModule()  # Not compiled
+
+    state_dict = {
+        "state": {
+            "model.0.weight": {"step": 1},
+            "model.0.bias": {"step": 1},
+        }
+    }
+
+    result = _align_compiled_param_names_with_module(state_dict, m)
+
+    # Keys should be unchanged
+    assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"}
