Skip to content

Commit c0a63b3

Browse files
authored
Merge branch 'master' into patch-3
2 parents 78d4037 + b554e99 commit c0a63b3

File tree

3 files changed

+77
-3
lines changed

3 files changed

+77
-3
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6161
- Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))
6262

6363

64+
- Fixed how `ThroughputMonitor` calculated training time ([#21291](https://github.com/Lightning-AI/pytorch-lightning/pull/21291))
65+
66+
6467
---
6568

6669
## [2.5.5] - 2025-09-05

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
109109
stage = trainer.state.stage
110110
assert stage is not None
111111

112-
if stage not in self._samples:
112+
reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples
113+
114+
if reset_needed:
113115
self._throughputs[stage].reset()
114116
self._lengths[stage] = 0
115117
self._samples[stage] = 0
@@ -202,10 +204,17 @@ def on_validation_batch_end(
202204
def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
203205
if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
204206
return
207+
208+
train_times = self._throughputs[RunningStage.TRAINING]._time
209+
val_times = self._throughputs[RunningStage.VALIDATING]._time
210+
211+
train_elapsed = train_times[-1] if train_times else 0.0
212+
val_elapsed = val_times[-1] if val_times else 0.0
213+
205214
# add the validation time to the training time before continuing to avoid sinking the training throughput
206-
training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
215+
training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed
207216
time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
208-
val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
217+
val_time = val_elapsed
209218
self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
210219

211220
@override

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,65 @@ def variable_batch_size_fn(batch):
420420
train_samples.append(metrics["train/samples"])
421421
elif "train|samples" in metrics:
422422
train_samples.append(metrics["train|samples"])
423+
424+
425+
def test_throughput_monitor_validation_with_many_epochs(tmp_path):
    """Fit for many epochs with validation and check every logged ``train/time`` value.

    ``time.perf_counter`` is patched with a scripted list of readings (five per epoch). After the fit,
    each ``train/time`` value found in the mocked logger's ``log_metrics`` calls is compared against a
    running total rebuilt from that same list.

    """
    logger_mock = Mock()
    logger_mock.save_dir = tmp_path
    monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
    model = BoringModel()
    model.flops_per_batch = 10
    num_epochs = 100

    trainer = Trainer(
        devices=1,
        logger=logger_mock,
        callbacks=[monitor],
        max_epochs=num_epochs,
        limit_train_batches=2,
        limit_val_batches=1,
        log_every_n_steps=1,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )

    # Scripted clock: each epoch contributes five readings, 10 seconds apart from the previous epoch.
    # NOTE(review): the labels below describe the intended timeline; confirm against the callback which
    # hook actually consumes each reading.
    timings = []
    t = 0.0
    for _ in range(num_epochs):
        timings += [
            t,  # train batch 1 start
            t + 3.0,  # train batch 1 end and start batch 2
            t + 6.0,  # train batch 2 end
            t + 7.0,  # val start
            t + 8.0,  # val end
        ]
        t += 10.0

    # Let any exception from ``fit`` propagate so pytest reports the full traceback.
    with mock.patch("time.perf_counter", side_effect=timings):
        trainer.fit(model)

    start_idx, end_idx = 0, 1
    batch_num = 1
    expected_train_time = timings[end_idx] - timings[start_idx]
    checked = 0
    for call in logger_mock.log_metrics.mock_calls:
        metrics = getattr(call, "kwargs", None) or {}
        metrics = metrics.get("metrics", metrics)
        for key, value in metrics.items():
            if not key.endswith("train/time"):
                continue
            assert value == expected_train_time, f"Expected train/time {expected_train_time}, got {value}"
            checked += 1

            # Move the index window: by one entry after the first batch of a pair, by three after the second.
            if batch_num == 1:
                step, batch_num = 1, 2
            else:
                step, batch_num = 3, 1
            start_idx += step
            end_idx += step
            if end_idx < len(timings):
                expected_train_time += timings[end_idx] - timings[start_idx]

    # Without this guard the loop above passes vacuously when no ``train/time`` metric is logged.
    assert checked > 0, "No train/time metric was logged, so nothing was verified"

0 commit comments

Comments
 (0)