Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))


- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))

---

## [2.5.5] - 2025-09-05
Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None:

@staticmethod
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
dst = getattr(new, name)
# Check if the parameter has been pruned (has _orig suffix)
dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name)
src = getattr(old, name)
if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
return
Expand Down
60 changes: 59 additions & 1 deletion tests/tests_pytorch/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def apply_lottery_ticket_hypothesis(self):
for i, name in names:
curr, curr_name = self._parameters_to_prune[i]
assert name == curr_name
actual, expected = getattr(curr, name).data, getattr(copy, name).data
# Check weight_orig if parameter is pruned, otherwise check the parameter directly
if hasattr(curr, name + "_orig"):
actual = getattr(curr, name + "_orig").data
else:
actual = getattr(curr, name).data
expected = getattr(copy, name).data
allclose = torch.allclose(actual.cpu(), expected)
assert not allclose if self._resample_parameters else allclose

Expand Down Expand Up @@ -405,3 +410,56 @@ def __init__(self):
for module, param_name in parameters_to_prune:
param = getattr(module, param_name)
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"


def test_lottery_ticket_hypothesis_correctly_reset(tmp_path):
"""Test that lottery ticket hypothesis correctly resets unpruned weights to original values."""
seed_everything(42)

class LTHTestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = nn.Linear(32, 2, bias=False)
with torch.no_grad():
# Initialize with a simple pattern for verification
self.layer.weight.copy_(torch.arange(1, 65, dtype=torch.float32).reshape(2, 32))

model = LTHTestModel()
original_weights = model.layer.weight.data.clone()

# Create a pruning callback that applies both pruning and LTH at epoch 1
pruning_callback = ModelPruning(
"l1_unstructured",
parameters_to_prune=[(model.layer, "weight")],
use_lottery_ticket_hypothesis=lambda epoch: epoch == 1,
amount=0.5,
verbose=0, # Reduce verbosity
make_pruning_permanent=False,
apply_pruning=lambda epoch: epoch == 1,
)

trainer = Trainer(
default_root_dir=tmp_path,
enable_progress_bar=False,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
limit_train_batches=5,
limit_val_batches=1,
max_epochs=2,
accelerator="cpu",
callbacks=pruning_callback,
)
trainer.fit(model)

# After training with LTH applied, check that weight_orig was reset correctly
assert hasattr(model.layer, "weight_mask"), "Pruning should have created weight_mask"
assert hasattr(model.layer, "weight_orig"), "Pruning should have created weight_orig"

weight_orig = getattr(model.layer, "weight_orig")
assert torch.allclose(weight_orig, original_weights, atol=1e-6), (
f"Lottery ticket hypothesis failed. weight_orig should be reset to original values.\n"
f"Expected weight_orig: {original_weights}\n"
f"Actual weight_orig: {weight_orig}\n"
f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
)
Loading