From 88b26becf9034e6d5963f098cf1f29eddd1a5386 Mon Sep 17 00:00:00 2001
From: Alex
Date: Thu, 29 Aug 2024 14:00:55 +0100
Subject: [PATCH 1/4] Support variable batch size in throughput callback

---
 src/lightning/pytorch/callbacks/throughput_monitor.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index a2d73d83184b1..5b4e9d39b4a06 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -87,6 +87,7 @@ def __init__(
         self._throughputs: Dict[RunningStage, Throughput] = {}
         self._t0s: Dict[RunningStage, float] = {}
         self._lengths: Dict[RunningStage, int] = {}
+        self._samples: Dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -109,6 +110,7 @@ def _start(self, trainer: "Trainer") -> None:
         self._throughputs[stage].reset()
         self._lengths[stage] = 0
         self._t0s[stage] = time.perf_counter()
+        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -133,12 +135,13 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
             flops_per_batch = None
 
-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+
         throughput.update(
             time=elapsed,
             batches=iter_num,
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,
         )

From f4ee1c77b5ee442ba51971947c35f0630d1f59d5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 10 Sep 2025 09:31:52 +0200
Subject: [PATCH 2/4] Fix implementation to accumulate batches and samples per
 stage

---
 .../pytorch/callbacks/throughput_monitor.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index 64d35449d136d..b88ee90bde38d 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -88,6 +88,7 @@ def __init__(
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
         self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -107,10 +108,14 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()
-        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -136,10 +141,10 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             flops_per_batch = None
 
         self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
 
         throughput.update(
             time=elapsed,
-            batches=iter_num,
-            # this assumes that all iterations used the same batch size
+            batches=self._batches[stage],
             samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],

From 2c99674b4d5052d0afb5d7cd4fe113853bdc84f6 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 10 Sep 2025 09:33:07 +0200
Subject: [PATCH 3/4] Add tests for variable batch size

---
 .../callbacks/test_throughput_monitor.py | 115 ++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
index 83bcb16c81797..7dda7875a43c7 100644
--- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -307,3 +307,118 @@ def test_throughput_monitor_eval(tmp_path, fn):
         call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
         call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
     ]
+
+
+def test_throughput_monitor_variable_batch_size(tmp_path):
+    """Test that ThroughputMonitor correctly handles variable batch sizes."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    # Simulate variable batch sizes by tracking calls
+    batch_sizes = [1, 3, 2, 1, 4]
+    call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        # Return the predefined batch size for this call
+        current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
+        call_count[0] += 1
+        return current_batch_size
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
+
+    model = BoringModel()
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=0,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
+
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
+    ):
+        trainer.fit(model)
+
+    log_calls = logger_mock.log_metrics.call_args_list
+    assert len(log_calls) == len(batch_sizes)
+
+    # Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
+    expected_cumulative_samples = [1, 4, 6, 7, 11]
+
+    for i, log_call in enumerate(log_calls):
+        metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
+        expected_samples = expected_cumulative_samples[i]
+        assert metrics["train|samples"] == expected_samples, (
+            f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
+        )
+        assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
+
+
+def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
+    """Test variable batch sizes with validation to ensure stage isolation."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    train_batch_sizes = [2, 1, 3]
+    val_batch_sizes = [1, 2]
+    train_call_count = [0]
+    val_call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        if hasattr(batch, "size") and batch.size(0) > 0:
+            if train_call_count[0] < len(train_batch_sizes):
+                current_batch_size = train_batch_sizes[train_call_count[0]]
+                train_call_count[0] += 1
+                return current_batch_size
+            current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
+            val_call_count[0] += 1
+            return current_batch_size
+        return 1
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(train_batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
+        trainer.fit(model)
+
+    # Verify that both training and validation metrics were logged
+    log_calls = logger_mock.log_metrics.call_args_list
+    train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
+    val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
+
+    assert len(train_calls) > 0, "Expected training metrics to be logged"
+    assert len(val_calls) > 0, "Expected validation metrics to be logged"
+
+    train_samples = []
+    for train_call in train_calls:
+        metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
+        if "train/samples" in metrics:
+            train_samples.append(metrics["train/samples"])
+        elif "train|samples" in metrics:
+            train_samples.append(metrics["train|samples"])
+    # Samples are cumulative per stage, so the logged counts must be non-decreasing
+    assert train_samples == sorted(train_samples)

From cfe0de8569641ddc54adb8de67998a2a2af00396 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 10 Sep 2025 09:34:08 +0200
Subject: [PATCH 4/4] Add changelog entry

---
 src/lightning/pytorch/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 28e1a60b4ae4b..94714421422a0 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
 
 
+- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
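
---

Reviewer note: a minimal usage sketch of what this series enables, for context only.
`ThroughputMonitor`, `batch_size_fn`, and `window_size` are the real API exercised by
the tests above; the dict-style batch carrying an `input_ids` tensor is an illustrative
assumption, not something defined in these patches.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ThroughputMonitor

    def batch_size_fn(batch):
        # With variable-size batches (e.g. from a bucketing sampler), report the
        # true per-iteration size instead of a fixed constant.
        return batch["input_ids"].size(0)

    monitor = ThroughputMonitor(batch_size_fn=batch_size_fn, window_size=50)
    trainer = Trainer(callbacks=[monitor], max_steps=100)
    # After this series, `train/samples` accumulates the actual per-batch sizes
    # rather than being computed as `iter_num * batch_size`.

With a fixed batch size the behaviour is unchanged, since summing a constant per
iteration reproduces the old `iter_num * batch_size` product.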