Skip to content

Commit d7f15d3

Browse files
authored
Merge branch 'master' into bugfix_lightningcli_missing_parent_callback
2 parents 53e04e1 + d73c50d commit d7f15d3

File tree

5 files changed

+195
-7
lines changed

5 files changed

+195
-7
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919
- Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
2020

2121

22+
- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
23+
24+
2225
### Changed
2326

2427
- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fall back to `TQDMProgressBar` and `ModelSummary` otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
@@ -36,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3639

3740
- Fixed case where `LightningCLI` could not be initialized with `trainer_default` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192))
3841

42+
43+
- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))
44+
45+
3946
---
4047

4148
## [2.5.5] - 2025-09-05

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
277277

278278
@staticmethod
279279
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
280-
dst = getattr(new, name)
280+
# Check if the parameter has been pruned (has _orig suffix)
281+
dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name)
281282
src = getattr(old, name)
282283
if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
283284
return

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def __init__(
8787
self._throughputs: dict[RunningStage, Throughput] = {}
8888
self._t0s: dict[RunningStage, float] = {}
8989
self._lengths: dict[RunningStage, int] = {}
90+
self._samples: dict[RunningStage, int] = {}
91+
self._batches: dict[RunningStage, int] = {}
9092

9193
@override
9294
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -106,8 +108,13 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
106108
def _start(self, trainer: "Trainer") -> None:
107109
stage = trainer.state.stage
108110
assert stage is not None
109-
self._throughputs[stage].reset()
110-
self._lengths[stage] = 0
111+
112+
if stage not in self._samples:
113+
self._throughputs[stage].reset()
114+
self._lengths[stage] = 0
115+
self._samples[stage] = 0
116+
self._batches[stage] = 0
117+
111118
self._t0s[stage] = time.perf_counter()
112119

113120
@torch.inference_mode() # in case `length_fn` or `batch_size_fn` computes grads
@@ -133,12 +140,14 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
133140
)
134141
flops_per_batch = None
135142

136-
batch_size = self.batch_size_fn(batch)
143+
self._samples[stage] += self.batch_size_fn(batch)
144+
self._batches[stage] += 1
145+
137146
throughput.update(
138147
time=elapsed,
139-
batches=iter_num,
148+
batches=self._batches[stage],
140149
# this assumes that all iterations used the same batch size
141-
samples=iter_num * batch_size,
150+
samples=self._samples[stage],
142151
lengths=None if self.length_fn is None else self._lengths[stage],
143152
flops=flops_per_batch, # type: ignore[arg-type]
144153
)

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ def apply_lottery_ticket_hypothesis(self):
205205
for i, name in names:
206206
curr, curr_name = self._parameters_to_prune[i]
207207
assert name == curr_name
208-
actual, expected = getattr(curr, name).data, getattr(copy, name).data
208+
# Check weight_orig if parameter is pruned, otherwise check the parameter directly
209+
if hasattr(curr, name + "_orig"):
210+
actual = getattr(curr, name + "_orig").data
211+
else:
212+
actual = getattr(curr, name).data
213+
expected = getattr(copy, name).data
209214
allclose = torch.allclose(actual.cpu(), expected)
210215
assert not allclose if self._resample_parameters else allclose
211216

@@ -405,3 +410,56 @@ def __init__(self):
405410
for module, param_name in parameters_to_prune:
406411
param = getattr(module, param_name)
407412
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
413+
414+
415+
def test_lottery_ticket_hypothesis_correctly_reset(tmp_path):
416+
"""Test that lottery ticket hypothesis correctly resets unpruned weights to original values."""
417+
seed_everything(42)
418+
419+
class LTHTestModel(BoringModel):
420+
def __init__(self):
421+
super().__init__()
422+
self.layer = nn.Linear(32, 2, bias=False)
423+
with torch.no_grad():
424+
# Initialize with a simple pattern for verification
425+
self.layer.weight.copy_(torch.arange(1, 65, dtype=torch.float32).reshape(2, 32))
426+
427+
model = LTHTestModel()
428+
original_weights = model.layer.weight.data.clone()
429+
430+
# Create a pruning callback that applies both pruning and LTH at epoch 1
431+
pruning_callback = ModelPruning(
432+
"l1_unstructured",
433+
parameters_to_prune=[(model.layer, "weight")],
434+
use_lottery_ticket_hypothesis=lambda epoch: epoch == 1,
435+
amount=0.5,
436+
verbose=0, # Reduce verbosity
437+
make_pruning_permanent=False,
438+
apply_pruning=lambda epoch: epoch == 1,
439+
)
440+
441+
trainer = Trainer(
442+
default_root_dir=tmp_path,
443+
enable_progress_bar=False,
444+
enable_model_summary=False,
445+
enable_checkpointing=False,
446+
logger=False,
447+
limit_train_batches=5,
448+
limit_val_batches=1,
449+
max_epochs=2,
450+
accelerator="cpu",
451+
callbacks=pruning_callback,
452+
)
453+
trainer.fit(model)
454+
455+
# After training with LTH applied, check that weight_orig was reset correctly
456+
assert hasattr(model.layer, "weight_mask"), "Pruning should have created weight_mask"
457+
assert hasattr(model.layer, "weight_orig"), "Pruning should have created weight_orig"
458+
459+
weight_orig = getattr(model.layer, "weight_orig")
460+
assert torch.allclose(weight_orig, original_weights, atol=1e-6), (
461+
f"Lottery ticket hypothesis failed. weight_orig should be reset to original values.\n"
462+
f"Expected weight_orig: {original_weights}\n"
463+
f"Actual weight_orig: {weight_orig}\n"
464+
f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
465+
)

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,116 @@ def test_throughput_monitor_eval(tmp_path, fn):
307307
call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
308308
call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
309309
]
310+
311+
312+
def test_throughput_monitor_variable_batch_size(tmp_path):
313+
"""Test that ThroughputMonitor correctly handles variable batch sizes."""
314+
logger_mock = Mock()
315+
logger_mock.save_dir = tmp_path
316+
317+
# Simulate variable batch sizes by tracking calls
318+
batch_sizes = [1, 3, 2, 1, 4]
319+
call_count = [0]
320+
321+
def variable_batch_size_fn(batch):
322+
# Return the predefined batch size for this call
323+
current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
324+
call_count[0] += 1
325+
return current_batch_size
326+
327+
monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
328+
329+
model = BoringModel()
330+
model.flops_per_batch = 10
331+
332+
trainer = Trainer(
333+
devices=1,
334+
logger=logger_mock,
335+
callbacks=monitor,
336+
max_steps=len(batch_sizes),
337+
log_every_n_steps=1,
338+
limit_val_batches=0,
339+
num_sanity_val_steps=0,
340+
enable_checkpointing=False,
341+
enable_model_summary=False,
342+
enable_progress_bar=False,
343+
)
344+
345+
timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
346+
347+
with (
348+
mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
349+
mock.patch("time.perf_counter", side_effect=timings),
350+
):
351+
trainer.fit(model)
352+
353+
log_calls = logger_mock.log_metrics.call_args_list
354+
assert len(log_calls) == len(batch_sizes)
355+
356+
# Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
357+
expected_cumulative_samples = [1, 4, 6, 7, 11]
358+
359+
for i, log_call in enumerate(log_calls):
360+
metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
361+
expected_samples = expected_cumulative_samples[i]
362+
assert metrics["train|samples"] == expected_samples, (
363+
f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
364+
)
365+
assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
366+
367+
368+
def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
369+
"""Test variable batch sizes with validation to ensure stage isolation."""
370+
logger_mock = Mock()
371+
logger_mock.save_dir = tmp_path
372+
373+
train_batch_sizes = [2, 1, 3]
374+
val_batch_sizes = [1, 2]
375+
train_call_count = [0]
376+
val_call_count = [0]
377+
378+
def variable_batch_size_fn(batch):
379+
if hasattr(batch, "size") and batch.size(0) > 0:
380+
if train_call_count[0] < len(train_batch_sizes):
381+
current_batch_size = train_batch_sizes[train_call_count[0]]
382+
train_call_count[0] += 1
383+
return current_batch_size
384+
current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
385+
val_call_count[0] += 1
386+
return current_batch_size
387+
return 1
388+
389+
monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
390+
model = BoringModel()
391+
392+
trainer = Trainer(
393+
devices=1,
394+
logger=logger_mock,
395+
callbacks=monitor,
396+
max_steps=len(train_batch_sizes),
397+
log_every_n_steps=1,
398+
limit_val_batches=2,
399+
val_check_interval=2,
400+
num_sanity_val_steps=0,
401+
enable_checkpointing=False,
402+
enable_model_summary=False,
403+
enable_progress_bar=False,
404+
)
405+
406+
with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
407+
trainer.fit(model)
408+
409+
# Verify that both training and validation metrics were logged
410+
log_calls = logger_mock.log_metrics.call_args_list
411+
train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
412+
val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
413+
414+
assert len(train_calls) > 0, "Expected training metrics to be logged"
415+
assert len(val_calls) > 0, "Expected validation metrics to be logged"
416+
train_samples = []
417+
for train_call in train_calls:
418+
metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
419+
if "train/samples" in metrics:
420+
train_samples.append(metrics["train/samples"])
421+
elif "train|samples" in metrics:
422+
train_samples.append(metrics["train|samples"])

0 commit comments

Comments
 (0)