14 changes: 3 additions & 11 deletions src/lightning/pytorch/CHANGELOG.md
@@ -6,19 +6,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

---

## [Unreleased] - YYYY-MM-DD

### Fixed

- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))

---


## [2.6.0] - 2025-11-28
54 changes: 53 additions & 1 deletion src/lightning/pytorch/strategies/model_parallel.py
@@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]:

state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
if not self._save_distributed_checkpoint and self.global_rank == 0:
# Store the optimizer state dict in standard format
state_dict = _align_compiled_param_names_with_module(state_dict, self.model)
state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
return state_dict

@@ -366,3 +366,55 @@ def set_world_ranks(self) -> None:
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank


def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]:
"""Align optimizer state dict keys with a module that may have compiled submodules.

When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``.
For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer
state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod``
prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from
``module.named_parameters()``.

This function inserts ``._orig_mod`` into the state dict keys where necessary so that
they match the module's ``named_parameters()`` output.

"""
from torch._dynamo import OptimizedModule

# Collect the prefixes of compiled submodules (e.g., "model" when that submodule was wrapped by torch.compile)
compiled_prefixes: list[str] = []
for name, submodule in module.named_modules():
if isinstance(submodule, OptimizedModule):
compiled_prefixes.append(name)

if not compiled_prefixes:
return state_dict

# Sort by length descending so longer prefixes are matched first
compiled_prefixes.sort(key=len, reverse=True)

def _transform_key(key: str) -> str:
for prefix in compiled_prefixes:
# Check if key starts with "prefix." (the compiled module path)
if key == prefix or key.startswith(prefix + "."):
suffix = key[len(prefix) :] # e.g., ".0.weight" or ""
# Insert _orig_mod between prefix and rest
return f"{prefix}._orig_mod{suffix}"
return key

# Transform keys in "state" section of the optimizer state dict
if "state" in state_dict:
new_state = {_transform_key(k): v for k, v in state_dict["state"].items()}
state_dict = {**state_dict, "state": new_state}

# Transform param names in "param_groups" section
if "param_groups" in state_dict:
new_param_groups = []
for group in state_dict["param_groups"]:
new_group = {**group, "params": [_transform_key(p) for p in group["params"]]}
new_param_groups.append(new_group)
state_dict = {**state_dict, "param_groups": new_param_groups}

return state_dict
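
For illustration, here is a minimal sketch of the mismatch this helper resolves, using plain PyTorch plus the new helper; the ``Wrapper`` module, its layer sizes, and the hand-built optimizer state dict are illustrative assumptions, not code from this PR:

import torch
from torch import nn

from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module


class Wrapper(nn.Module):
    # Illustrative module: only the inner Sequential gets compiled, mirroring the tests below.
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


wrapper = Wrapper()
wrapper.model = torch.compile(wrapper.model)  # the submodule is now an OptimizedModule

# The parent module now reports its parameters under "_orig_mod":
print(sorted(name for name, _ in wrapper.named_parameters()))
# ['model._orig_mod.0.bias', 'model._orig_mod.0.weight', 'model._orig_mod.1.bias', 'model._orig_mod.1.weight']

# An optimizer state dict keyed without that prefix no longer matches named_parameters(),
# which is what made the PARAM_ID rekeying above raise KeyError; the helper inserts the prefix:
state_dict = {
    "state": {"model.0.weight": {"step": 1}},
    "param_groups": [{"params": ["model.0.weight"], "lr": 0.01}],
}
aligned = _align_compiled_param_names_with_module(state_dict, wrapper)
print(list(aligned["state"]))  # ['model._orig_mod.0.weight']
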
75 changes: 75 additions & 0 deletions tests/tests_pytorch/strategies/test_model_parallel.py
@@ -251,3 +251,78 @@ def configure_model(self) -> None:
strategy.setup(Mock())
assert all(not p.is_meta for p in model.parameters())
assert all(not b.is_meta for b in model.buffers())


@RunIf(min_torch="2.4")
def test_align_compiled_param_names_with_module():
"""Test that optimizer state dict keys are aligned with compiled submodule parameter names."""
from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module

class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

def forward(self, x):
return self.model(x)

# Test with compiled submodule
m = SimpleModule()
m.model = torch.compile(m.model)

# Simulate optimizer state dict without _orig_mod in keys (includes both state and param_groups)
state_dict = {
"state": {
"model.0.weight": {"step": 1},
"model.0.bias": {"step": 1},
"model.2.weight": {"step": 1},
"model.2.bias": {"step": 1},
},
"param_groups": [{"params": ["model.0.weight", "model.0.bias", "model.2.weight", "model.2.bias"], "lr": 0.01}],
}

result = _align_compiled_param_names_with_module(state_dict, m)

# Verify state keys now have _orig_mod inserted
expected_keys = {
"model._orig_mod.0.weight",
"model._orig_mod.0.bias",
"model._orig_mod.2.weight",
"model._orig_mod.2.bias",
}
assert set(result["state"].keys()) == expected_keys

# Verify param_groups params also have _orig_mod inserted
assert set(result["param_groups"][0]["params"]) == expected_keys

# Verify they match the module's named_parameters
param_names = {name for name, _ in m.named_parameters()}
assert set(result["state"].keys()) == param_names


@RunIf(min_torch="2.4")
def test_align_compiled_param_names_no_compile():
"""Test that non-compiled modules pass through unchanged."""
from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module

class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))

def forward(self, x):
return self.model(x)

m = SimpleModule() # Not compiled

state_dict = {
"state": {
"model.0.weight": {"step": 1},
"model.0.bias": {"step": 1},
}
}

result = _align_compiled_param_names_with_module(state_dict, m)

# Keys should be unchanged
assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"}
@@ -135,6 +135,33 @@ def configure_model(self):
parallelize(self.model, device_mesh=self.device_mesh)


class SimpleCompiledModule(LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
self._loss = nn.MSELoss()

def configure_model(self):
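# Compiling the inner Sequential makes its parameters appear under "model._orig_mod.*", reproducing the naming the fix handles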
self.model = torch.compile(self.model)

def training_step(self, batch, batch_idx):
x, y = batch
preds = self.model(x)
return self._loss(preds, y)

def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-3)


def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2):
total_samples = batch_size * num_batches
generator = torch.Generator().manual_seed(0)
features = torch.randn(total_samples, 32, generator=generator)
targets = torch.randn(total_samples, 32, generator=generator)
dataset = torch.utils.data.TensorDataset(features, targets)
return DataLoader(dataset, batch_size=batch_size)


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_setup_device_mesh(distributed):
from torch.distributed.device_mesh import DeviceMesh
@@ -237,6 +264,44 @@ def training_step(self, batch):
trainer.fit(model)


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
"""Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing."""

seed_everything(0)
strategy = ModelParallelStrategy(
data_parallel_size=1,
tensor_parallel_size=1,
save_distributed_checkpoint=False,
)

trainer = Trainer(
accelerator="auto",
devices=1,
strategy=strategy,
max_steps=2,
limit_train_batches=2,
enable_checkpointing=False,
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
default_root_dir=tmp_path,
)

dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2)

with trainer.init_module(empty_init=True):
model = SimpleCompiledModule()

trainer.fit(model, dataloader)

if trainer.is_global_zero:
checkpoint_path = tmp_path / "compiled-model.ckpt"
trainer.save_checkpoint(checkpoint_path)

trainer.strategy.barrier()


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
@pytest.mark.parametrize(
"compile",