3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))


- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fall back to `TQDMProgressBar` and `ModelSummary` otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
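For context, a minimal usage sketch of the feature this entry refers to (not part of the diff; it assumes the `ThroughputMonitor(batch_size_fn=...)` constructor shown below and batches whose first dimension is the sample count): with this change, `batch_size_fn` may return a different value for every batch, and the reported `samples` metric accumulates the actual sizes instead of assuming a constant batch size.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.pytorch.demos.boring_classes import BoringModel

# Report the true number of samples in each (possibly variable-sized) batch.
# `batch.size(0)` is an assumption about the batch layout (first dim = samples).
monitor = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))

trainer = Trainer(callbacks=[monitor], max_steps=100, log_every_n_steps=1)
trainer.fit(BoringModel())
```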
19 changes: 14 additions & 5 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -87,6 +87,8 @@ def __init__(
self._throughputs: dict[RunningStage, Throughput] = {}
self._t0s: dict[RunningStage, float] = {}
self._lengths: dict[RunningStage, int] = {}
self._samples: dict[RunningStage, int] = {}
self._batches: dict[RunningStage, int] = {}

@override
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -106,8 +108,13 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
def _start(self, trainer: "Trainer") -> None:
stage = trainer.state.stage
assert stage is not None
self._throughputs[stage].reset()
self._lengths[stage] = 0

if stage not in self._samples:
self._throughputs[stage].reset()
self._lengths[stage] = 0
self._samples[stage] = 0
self._batches[stage] = 0

self._t0s[stage] = time.perf_counter()

@torch.inference_mode() # in case `length_fn` or `batch_size_fn` computes grads
@@ -133,12 +140,14 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
)
flops_per_batch = None

batch_size = self.batch_size_fn(batch)
self._samples[stage] += self.batch_size_fn(batch)
self._batches[stage] += 1

throughput.update(
time=elapsed,
batches=iter_num,
batches=self._batches[stage],
# this assumes that all iterations used the same batch size
samples=iter_num * batch_size,
samples=self._samples[stage],
lengths=None if self.length_fn is None else self._lengths[stage],
flops=flops_per_batch, # type: ignore[arg-type]
)
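A back-of-the-envelope illustration of the accounting change above (plain Python, not the callback itself): previously `samples` was computed as `iter_num * batch_size`, which is only correct when every batch has the same size; the new cumulative counters make the metric exact when sizes vary.

```python
# Illustration only: compare the old and new sample accounting for
# per-batch sizes of [1, 3, 2, 1, 4].
batch_sizes = [1, 3, 2, 1, 4]

cumulative = 0
for iter_num, size in enumerate(batch_sizes, start=1):
    old_samples = iter_num * size  # old: assumes a constant batch size
    cumulative += size             # new: self._samples[stage] += batch_size_fn(batch)
    print(iter_num, old_samples, cumulative)
# After 5 batches the old formula would report 5 * 4 = 20 samples,
# while the cumulative counter reports the true total of 11.
```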
113 changes: 113 additions & 0 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -307,3 +307,116 @@ def test_throughput_monitor_eval(tmp_path, fn):
call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
]


def test_throughput_monitor_variable_batch_size(tmp_path):
"""Test that ThroughputMonitor correctly handles variable batch sizes."""
logger_mock = Mock()
logger_mock.save_dir = tmp_path

# Simulate variable batch sizes by tracking calls
batch_sizes = [1, 3, 2, 1, 4]
call_count = [0]

def variable_batch_size_fn(batch):
# Return the predefined batch size for this call
current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
call_count[0] += 1
return current_batch_size

monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")

model = BoringModel()
model.flops_per_batch = 10

trainer = Trainer(
devices=1,
logger=logger_mock,
callbacks=monitor,
max_steps=len(batch_sizes),
log_every_n_steps=1,
limit_val_batches=0,
num_sanity_val_steps=0,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)

timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]

with (
mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
mock.patch("time.perf_counter", side_effect=timings),
):
trainer.fit(model)

log_calls = logger_mock.log_metrics.call_args_list
assert len(log_calls) == len(batch_sizes)

# Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
expected_cumulative_samples = [1, 4, 6, 7, 11]

for i, log_call in enumerate(log_calls):
metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
expected_samples = expected_cumulative_samples[i]
assert metrics["train|samples"] == expected_samples, (
f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
)
assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"


def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
"""Test variable batch sizes with validation to ensure stage isolation."""
logger_mock = Mock()
logger_mock.save_dir = tmp_path

train_batch_sizes = [2, 1, 3]
val_batch_sizes = [1, 2]
train_call_count = [0]
val_call_count = [0]

def variable_batch_size_fn(batch):
if hasattr(batch, "size") and batch.size(0) > 0:
if train_call_count[0] < len(train_batch_sizes):
current_batch_size = train_batch_sizes[train_call_count[0]]
train_call_count[0] += 1
return current_batch_size
current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
val_call_count[0] += 1
return current_batch_size
return 1

monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
model = BoringModel()

trainer = Trainer(
devices=1,
logger=logger_mock,
callbacks=monitor,
max_steps=len(train_batch_sizes),
log_every_n_steps=1,
limit_val_batches=2,
val_check_interval=2,
num_sanity_val_steps=0,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)

with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
trainer.fit(model)

# Verify that both training and validation metrics were logged
log_calls = logger_mock.log_metrics.call_args_list
train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]

assert len(train_calls) > 0, "Expected training metrics to be logged"
assert len(val_calls) > 0, "Expected validation metrics to be logged"
train_samples = []
for train_call in train_calls:
metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
if "train/samples" in metrics:
train_samples.append(metrics["train/samples"])
elif "train|samples" in metrics:
train_samples.append(metrics["train|samples"])