Skip to content

Commit 4b19400

Browse files
author
itzhaks
committed
fix: calculate training time from the last elapsed-time entry instead of summing all entries
1 parent feb8fa1 commit 4b19400

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
109109
stage = trainer.state.stage
110110
assert stage is not None
111111

112-
if stage not in self._samples:
112+
reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples
113+
114+
if reset_needed:
113115
self._throughputs[stage].reset()
114116
self._lengths[stage] = 0
115117
self._samples[stage] = 0
@@ -202,10 +204,17 @@ def on_validation_batch_end(
202204
def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
203205
if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
204206
return
207+
208+
train_times = self._throughputs[RunningStage.TRAINING]._time
209+
val_times = self._throughputs[RunningStage.VALIDATING]._time
210+
211+
train_elapsed = train_times[-1] if train_times else 0.0
212+
val_elapsed = val_times[-1] if val_times else 0.0
213+
205214
# add the validation time to the training time before continuing to avoid sinking the training throughput
206-
training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
215+
training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed
207216
time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
208-
val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
217+
val_time = val_elapsed
209218
self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
210219

211220
@override

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -420,3 +420,26 @@ def variable_batch_size_fn(batch):
420420
train_samples.append(metrics["train/samples"])
421421
elif "train|samples" in metrics:
422422
train_samples.append(metrics["train|samples"])
423+
424+
425+
def test_throughput_monitor_validation_sum_overflow_real(tmp_path):
426+
logger_mock = Mock()
427+
logger_mock.save_dir = tmp_path
428+
monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
429+
model = BoringModel()
430+
model.flops_per_batch = 10
431+
432+
trainer = Trainer(
433+
devices=1,
434+
logger=logger_mock,
435+
callbacks=[monitor],
436+
max_epochs=100,
437+
enable_checkpointing=False,
438+
enable_model_summary=False,
439+
enable_progress_bar=False,
440+
)
441+
442+
try:
443+
trainer.fit(model)
444+
except Exception as e:
445+
pytest.fail(f"ThroughputMonitor raised an unexpected exception: {e}")

0 commit comments

Comments
 (0)