diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3558fe3cadc66..f8be6c68cd71a 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -6,19 +6,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). --- -## [Unreleased] - YYYY-MM-DD - -### Added - -- - -### Changed - -- +### Fixed -### Removed +- Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357)) -- +--- ## [2.6.0] - 2025-11-28 diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index f3165a08e6bdd..4ad8b08c7d2ea 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -286,7 +286,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]: state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options) if not self._save_distributed_checkpoint and self.global_rank == 0: - # Store the optimizer state dict in standard format + state_dict = _align_compiled_param_names_with_module(state_dict, self.model) state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model) return state_dict @@ -366,3 +366,55 @@ def set_world_ranks(self) -> None: # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank + + +def _align_compiled_param_names_with_module(state_dict: dict[str, Any], module: torch.nn.Module) -> dict[str, Any]: + """Align optimizer state dict keys with a module that may have compiled submodules. + + When ``torch.compile`` wraps a submodule, its parameters appear under ``_orig_mod``. + For example, ``model.0.weight`` becomes ``model._orig_mod.0.weight``. The optimizer + state dict returned by ``get_optimizer_state_dict`` may not include the ``_orig_mod`` + prefix, causing a mismatch when ``rekey_optim_state_dict`` builds its mapping from + ``module.named_parameters()``. + + This function inserts ``._orig_mod`` into the state dict keys where necessary so that + they match the module's ``named_parameters()`` output. + + """ + from torch._dynamo import OptimizedModule + + # Build set of compiled submodule prefixes (e.g., "model" if model is compiled) + compiled_prefixes: list[str] = [] + for name, submodule in module.named_modules(): + if isinstance(submodule, OptimizedModule): + compiled_prefixes.append(name) + + if not compiled_prefixes: + return state_dict + + # Sort by length descending so longer prefixes are matched first + compiled_prefixes.sort(key=len, reverse=True) + + def _transform_key(key: str) -> str: + for prefix in compiled_prefixes: + # Check if key starts with "prefix." (the compiled module path) + if key == prefix or key.startswith(prefix + "."): + suffix = key[len(prefix) :] # e.g., ".0.weight" or "" + # Insert _orig_mod between prefix and rest + return f"{prefix}._orig_mod{suffix}" + return key + + # Transform keys in "state" section of the optimizer state dict + if "state" in state_dict: + new_state = {_transform_key(k): v for k, v in state_dict["state"].items()} + state_dict = {**state_dict, "state": new_state} + + # Transform param names in "param_groups" section + if "param_groups" in state_dict: + new_param_groups = [] + for group in state_dict["param_groups"]: + new_group = {**group, "params": [_transform_key(p) for p in group["params"]]} + new_param_groups.append(new_group) + state_dict = {**state_dict, "param_groups": new_param_groups} + + return state_dict diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index c803c10afa4b4..33cf45326290c 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -251,3 +251,78 @@ def configure_model(self) -> None: strategy.setup(Mock()) assert all(not p.is_meta for p in model.parameters()) assert all(not b.is_meta for b in model.buffers()) + + +@RunIf(min_torch="2.4") +def test_align_compiled_param_names_with_module(): + """Test that optimizer state dict keys are aligned with compiled submodule parameter names.""" + from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module + + class SimpleModule(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32)) + + def forward(self, x): + return self.model(x) + + # Test with compiled submodule + m = SimpleModule() + m.model = torch.compile(m.model) + + # Simulate optimizer state dict without _orig_mod in keys (includes both state and param_groups) + state_dict = { + "state": { + "model.0.weight": {"step": 1}, + "model.0.bias": {"step": 1}, + "model.2.weight": {"step": 1}, + "model.2.bias": {"step": 1}, + }, + "param_groups": [{"params": ["model.0.weight", "model.0.bias", "model.2.weight", "model.2.bias"], "lr": 0.01}], + } + + result = _align_compiled_param_names_with_module(state_dict, m) + + # Verify state keys now have _orig_mod inserted + expected_keys = { + "model._orig_mod.0.weight", + "model._orig_mod.0.bias", + "model._orig_mod.2.weight", + "model._orig_mod.2.bias", + } + assert set(result["state"].keys()) == expected_keys + + # Verify param_groups params also have _orig_mod inserted + assert set(result["param_groups"][0]["params"]) == expected_keys + + # Verify they match the module's named_parameters + param_names = {name for name, _ in m.named_parameters()} + assert set(result["state"].keys()) == param_names + + +@RunIf(min_torch="2.4") +def test_align_compiled_param_names_no_compile(): + """Test that non-compiled modules pass through unchanged.""" + from lightning.pytorch.strategies.model_parallel import _align_compiled_param_names_with_module + + class SimpleModule(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32)) + + def forward(self, x): + return self.model(x) + + m = SimpleModule() # Not compiled + + state_dict = { + "state": { + "model.0.weight": {"step": 1}, + "model.0.bias": {"step": 1}, + } + } + + result = _align_compiled_param_names_with_module(state_dict, m) + + # Keys should be unchanged + assert set(result["state"].keys()) == {"model.0.weight", "model.0.bias"} diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 4b3dbe9df9724..48ce013d22d3d 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -135,6 +135,33 @@ def configure_model(self): parallelize(self.model, device_mesh=self.device_mesh) +class SimpleCompiledModule(LightningModule): + def __init__(self): + super().__init__() + self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32)) + self._loss = nn.MSELoss() + + def configure_model(self): + self.model = torch.compile(self.model) + + def training_step(self, batch, batch_idx): + x, y = batch + preds = self.model(x) + return self._loss(preds, y) + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=1e-3) + + +def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2): + total_samples = batch_size * num_batches + generator = torch.Generator().manual_seed(0) + features = torch.randn(total_samples, 32, generator=generator) + targets = torch.randn(total_samples, 32, generator=generator) + dataset = torch.utils.data.TensorDataset(features, targets) + return DataLoader(dataset, batch_size=batch_size) + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh @@ -237,6 +264,44 @@ def training_step(self, batch): trainer.fit(model) +@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) +def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path): + """Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing.""" + + seed_everything(0) + strategy = ModelParallelStrategy( + data_parallel_size=1, + tensor_parallel_size=1, + save_distributed_checkpoint=False, + ) + + trainer = Trainer( + accelerator="auto", + devices=1, + strategy=strategy, + max_steps=2, + limit_train_batches=2, + enable_checkpointing=False, + logger=False, + enable_progress_bar=False, + enable_model_summary=False, + default_root_dir=tmp_path, + ) + + dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2) + + with trainer.init_module(empty_init=True): + model = SimpleCompiledModule() + + trainer.fit(model, dataloader) + + if trainer.is_global_zero: + checkpoint_path = tmp_path / "compiled-model.ckpt" + trainer.save_checkpoint(checkpoint_path) + + trainer.strategy.barrier() + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) @pytest.mark.parametrize( "compile",