
Commit f4e0a19

Authored by alex-hh and SkafteNicki
Support variable batch size in throughput callback (#20236)
* Support variable batch size in throughput callback
* fix implementation
* add testing
* changelog

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 3998b5d commit f4e0a19
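
For orientation, a minimal usage sketch of the behavior this commit enables. `ThroughputMonitor` and its `batch_size_fn` argument are the existing API; the dict-shaped batch with a variable-length `"inputs"` tensor and the rest of the training setup are illustrative assumptions only.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ThroughputMonitor

# Assumption: each batch is a dict whose "inputs" tensor has a varying first dimension,
# e.g. when a custom sampler groups sequences of different lengths.
monitor = ThroughputMonitor(batch_size_fn=lambda batch: batch["inputs"].shape[0])

trainer = Trainer(callbacks=[monitor], max_steps=100)
# trainer.fit(model)  # model omitted here; the new tests below show a complete, runnable setup
```

With this commit, the per-batch values returned by `batch_size_fn` are accumulated per stage, so the logged `samples` metric stays correct even when batch sizes differ between iterations.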

File tree: 3 files changed (+130, -5 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added time-based validation support though `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))


+- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
+
+
 ### Changed

 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
```

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -87,6 +87,8 @@ def __init__(
         self._throughputs: dict[RunningStage, Throughput] = {}
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
+        self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}

     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -106,8 +108,13 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()

     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
@@ -133,12 +140,14 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
             flops_per_batch = None

-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
+
         throughput.update(
             time=elapsed,
-            batches=iter_num,
+            batches=self._batches[stage],
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,  # type: ignore[arg-type]
         )
```
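
The crux of the change is in `_update`: the old code extrapolated `samples = iter_num * batch_size` from the most recent batch, which is only correct when every iteration uses the same batch size, while the new code keeps per-stage running totals. A standalone sketch of the two bookkeeping strategies (plain Python, not Lightning code) illustrates the difference:

```python
# Compare the old extrapolation with the new accumulation when batch sizes vary.
batch_sizes = [1, 3, 2, 1, 4]  # per-iteration sizes, matching the test below

samples_accumulated = 0
for iter_num, batch_size in enumerate(batch_sizes, start=1):
    samples_accumulated += batch_size             # new: running total per stage
    samples_extrapolated = iter_num * batch_size  # old: assumes a constant batch size
    print(iter_num, samples_accumulated, samples_extrapolated)

# After 5 iterations the running total is 11 samples (1 + 3 + 2 + 1 + 4),
# whereas the old extrapolation would have reported 5 * 4 = 20.
```

The `batches` count is handled the same way, via `self._batches[stage]` instead of `iter_num`.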

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 113 additions & 0 deletions
```diff
@@ -307,3 +307,116 @@ def test_throughput_monitor_eval(tmp_path, fn):
         call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
         call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
     ]
+
+
+def test_throughput_monitor_variable_batch_size(tmp_path):
+    """Test that ThroughputMonitor correctly handles variable batch sizes."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    # Simulate variable batch sizes by tracking calls
+    batch_sizes = [1, 3, 2, 1, 4]
+    call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        # Return the predefined batch size for this call
+        current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
+        call_count[0] += 1
+        return current_batch_size
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
+
+    model = BoringModel()
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=0,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
+
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
+    ):
+        trainer.fit(model)
+
+    log_calls = logger_mock.log_metrics.call_args_list
+    assert len(log_calls) == len(batch_sizes)
+
+    # Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
+    expected_cumulative_samples = [1, 4, 6, 7, 11]
+
+    for i, log_call in enumerate(log_calls):
+        metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
+        expected_samples = expected_cumulative_samples[i]
+        assert metrics["train|samples"] == expected_samples, (
+            f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
+        )
+        assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
+
+
+def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
+    """Test variable batch sizes with validation to ensure stage isolation."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    train_batch_sizes = [2, 1, 3]
+    val_batch_sizes = [1, 2]
+    train_call_count = [0]
+    val_call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        if hasattr(batch, "size") and batch.size(0) > 0:
+            if train_call_count[0] < len(train_batch_sizes):
+                current_batch_size = train_batch_sizes[train_call_count[0]]
+                train_call_count[0] += 1
+                return current_batch_size
+            current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
+            val_call_count[0] += 1
+            return current_batch_size
+        return 1
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(train_batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
+        trainer.fit(model)
+
+    # Verify that both training and validation metrics were logged
+    log_calls = logger_mock.log_metrics.call_args_list
+    train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
+    val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
+
+    assert len(train_calls) > 0, "Expected training metrics to be logged"
+    assert len(val_calls) > 0, "Expected validation metrics to be logged"
+    train_samples = []
+    for train_call in train_calls:
+        metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
+        if "train/samples" in metrics:
+            train_samples.append(metrics["train/samples"])
+        elif "train|samples" in metrics:
+            train_samples.append(metrics["train|samples"])
```
