Skip to content

Commit d73c50d

Browse files
authored
Fix missing reset when pruning with lottery ticket hypothesis (#21191)
* fix pruning hasattr * add testing * changelog
1 parent f4e0a19 commit d73c50d

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Fixed edge case when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))
3838

3939

40+
- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))
41+
4042
---
4143

4244
## [2.5.5] - 2025-09-05

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
277277

278278
@staticmethod
279279
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
280-
dst = getattr(new, name)
280+
# Check if the parameter has been pruned (has _orig suffix)
281+
dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name)
281282
src = getattr(old, name)
282283
if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
283284
return

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ def apply_lottery_ticket_hypothesis(self):
205205
for i, name in names:
206206
curr, curr_name = self._parameters_to_prune[i]
207207
assert name == curr_name
208-
actual, expected = getattr(curr, name).data, getattr(copy, name).data
208+
# Check weight_orig if parameter is pruned, otherwise check the parameter directly
209+
if hasattr(curr, name + "_orig"):
210+
actual = getattr(curr, name + "_orig").data
211+
else:
212+
actual = getattr(curr, name).data
213+
expected = getattr(copy, name).data
209214
allclose = torch.allclose(actual.cpu(), expected)
210215
assert not allclose if self._resample_parameters else allclose
211216

@@ -405,3 +410,56 @@ def __init__(self):
405410
for module, param_name in parameters_to_prune:
406411
param = getattr(module, param_name)
407412
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
413+
414+
415+
def test_lottery_ticket_hypothesis_correctly_reset(tmp_path):
    """Test that lottery ticket hypothesis correctly resets unpruned weights to original values.

    A single bias-free ``nn.Linear(32, 2)`` layer is initialised with the values
    1..64, pruned with ``l1_unstructured`` and reset by the lottery ticket
    hypothesis, both triggered by ``epoch == 1``. After fitting, the layer's
    ``weight_orig`` must equal the snapshot taken before training.

    Args:
        tmp_path: Directory passed to the ``Trainer`` as ``default_root_dir``.

    """
    seed_everything(42)

    class LTHTestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2, bias=False)
            with torch.no_grad():
                # All-distinct, known values (1..64) so any deviation from the
                # snapshot below is easy to spot in a failure message.
                self.layer.weight.copy_(torch.arange(1, 65, dtype=torch.float32).reshape(2, 32))

    model = LTHTestModel()
    # Snapshot taken before training; this is the reference for the final assertion.
    original_weights = model.layer.weight.data.clone()

    # Pruning and the lottery ticket reset are both gated on the same epoch.
    pruning_callback = ModelPruning(
        "l1_unstructured",
        parameters_to_prune=[(model.layer, "weight")],
        use_lottery_ticket_hypothesis=lambda epoch: epoch == 1,
        amount=0.5,
        verbose=0,
        # Keep the pruning re-parametrisation in place so that `weight_orig`
        # and `weight_mask` still exist on the layer after `fit`.
        make_pruning_permanent=False,
        apply_pruning=lambda epoch: epoch == 1,
    )

    trainer = Trainer(
        default_root_dir=tmp_path,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=5,
        limit_val_batches=1,
        max_epochs=2,
        accelerator="cpu",
        callbacks=pruning_callback,
    )
    trainer.fit(model)

    # After training with LTH applied, check that weight_orig was reset correctly
    assert hasattr(model.layer, "weight_mask"), "Pruning should have created weight_mask"
    assert hasattr(model.layer, "weight_orig"), "Pruning should have created weight_orig"

    weight_orig = model.layer.weight_orig
    assert torch.allclose(weight_orig, original_weights, atol=1e-6), (
        "Lottery ticket hypothesis failed. weight_orig should be reset to original values.\n"
        f"Expected weight_orig: {original_weights}\n"
        f"Actual weight_orig: {weight_orig}\n"
        f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
    )

0 commit comments

Comments
 (0)