diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 28e1a60b4ae4b..94714421422a0 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added time-based validation support though `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
 
+- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index 8b618ae2be912..b88ee90bde38d 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -87,6 +87,8 @@ def __init__(
         self._throughputs: dict[RunningStage, Throughput] = {}
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
+        self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -106,8 +108,13 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
@@ -133,12 +140,14 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
 
         flops_per_batch = None
-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
+
         throughput.update(
             time=elapsed,
-            batches=iter_num,
-            # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            batches=self._batches[stage],
+            # samples are accumulated from the actual size of each batch, so variable batch sizes are counted exactly
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,  # type: ignore[arg-type]
         )
diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
index 83bcb16c81797..7dda7875a43c7 100644
--- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -307,3 +307,116 @@
         call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
         call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
     ]
+
+
+def test_throughput_monitor_variable_batch_size(tmp_path):
+    """Test that ThroughputMonitor correctly handles variable batch sizes."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    # Simulate variable batch sizes by tracking calls
+    batch_sizes = [1, 3, 2, 1, 4]
+    call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        # Return the predefined batch size for this call
+        current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
+        call_count[0] += 1
+        return current_batch_size
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
+
+    model = BoringModel()
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=0,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
+
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
+    ):
+        trainer.fit(model)
+
+    log_calls = logger_mock.log_metrics.call_args_list
+    assert len(log_calls) == len(batch_sizes)
+
+    # Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
+    expected_cumulative_samples = [1, 4, 6, 7, 11]
+
+    for i, log_call in enumerate(log_calls):
+        metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
+        expected_samples = expected_cumulative_samples[i]
+        assert metrics["train|samples"] == expected_samples, (
+            f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
+        )
+        assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
+
+
+def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
+    """Test variable batch sizes with validation to ensure stage isolation."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    train_batch_sizes = [2, 1, 3]
+    val_batch_sizes = [1, 2]
+    train_call_count = [0]
+    val_call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        if hasattr(batch, "size") and batch.size(0) > 0:
+            if train_call_count[0] < len(train_batch_sizes):
+                current_batch_size = train_batch_sizes[train_call_count[0]]
+                train_call_count[0] += 1
+                return current_batch_size
+            current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
+            val_call_count[0] += 1
+            return current_batch_size
+        return 1
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(train_batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
+        trainer.fit(model)
+
+    # Verify that both training and validation metrics were logged
+    log_calls = logger_mock.log_metrics.call_args_list
+    train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
+    val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
+
+    assert len(train_calls) > 0, "Expected training metrics to be logged"
+    assert len(val_calls) > 0, "Expected validation metrics to be logged"
+    train_samples = []
+    for train_call in train_calls:
+        metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
+        if "train/samples" in metrics:
+            train_samples.append(metrics["train/samples"])
+        elif "train|samples" in metrics:
+            train_samples.append(metrics["train|samples"])