diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 28e1a60b4ae4b..f55ef96703604 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187)) +- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191)) + --- ## [2.5.5] - 2025-09-05 diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 93e828db0d24f..1de693978acfa 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None: @staticmethod def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None: - dst = getattr(new, name) + # Check if the parameter has been pruned (has _orig suffix) + dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name) src = getattr(old, name) if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor): return diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index 6efe9b9992d00..20bb03bfdd941 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -205,7 +205,12 @@ def apply_lottery_ticket_hypothesis(self): for i, name in names: curr, curr_name = self._parameters_to_prune[i] assert name == curr_name - actual, expected = getattr(curr, name).data, getattr(copy, name).data + # Check weight_orig if parameter is pruned, otherwise check the parameter directly + if hasattr(curr, name + "_orig"): + actual = getattr(curr, name + "_orig").data + else: + actual = getattr(curr, name).data + expected = getattr(copy, name).data allclose = torch.allclose(actual.cpu(), expected) assert not allclose if self._resample_parameters else allclose @@ -405,3 +410,56 @@ def __init__(self): for module, param_name in parameters_to_prune: param = getattr(module, param_name) assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}" + + +def test_lottery_ticket_hypothesis_correctly_reset(tmp_path): + """Test that lottery ticket hypothesis correctly resets unpruned weights to original values.""" + seed_everything(42) + + class LTHTestModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = nn.Linear(32, 2, bias=False) + with torch.no_grad(): + # Initialize with a simple pattern for verification + self.layer.weight.copy_(torch.arange(1, 65, dtype=torch.float32).reshape(2, 32)) + + model = LTHTestModel() + original_weights = model.layer.weight.data.clone() + + # Create a pruning callback that applies both pruning and LTH at epoch 1 + pruning_callback = ModelPruning( + "l1_unstructured", + parameters_to_prune=[(model.layer, "weight")], + use_lottery_ticket_hypothesis=lambda epoch: epoch == 1, + amount=0.5, + verbose=0, # Reduce verbosity + make_pruning_permanent=False, + apply_pruning=lambda epoch: epoch == 1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + enable_progress_bar=False, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + limit_train_batches=5, + limit_val_batches=1, + max_epochs=2, + accelerator="cpu", + callbacks=pruning_callback, + ) + trainer.fit(model) + + # After training with LTH applied, check that weight_orig was reset correctly + assert hasattr(model.layer, "weight_mask"), "Pruning should have created weight_mask" + assert hasattr(model.layer, "weight_orig"), "Pruning should have created weight_orig" + + weight_orig = getattr(model.layer, "weight_orig") + assert torch.allclose(weight_orig, original_weights, atol=1e-6), ( + f"Lottery ticket hypothesis failed. weight_orig should be reset to original values.\n" + f"Expected weight_orig: {original_weights}\n" + f"Actual weight_orig: {weight_orig}\n" + f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}" + )