15 changes: 12 additions & 3 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
 
-        if stage not in self._samples:
+        reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples
+
+        if reset_needed:
             self._throughputs[stage].reset()
             self._lengths[stage] = 0
             self._samples[stage] = 0
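For illustration only (not part of the diff), the new condition can be read as a standalone predicate. The enum import path is the standard `lightning.pytorch.trainer.states` one, and the `ThroughputMonitor` bookkeeping is simplified to a plain set here:

```python
from lightning.pytorch.trainer.states import RunningStage, TrainerFn


def reset_needed(fn: TrainerFn, stage: RunningStage, seen_stages: set) -> bool:
    # While fitting, train and validation stages start repeatedly, so the per-stage
    # counters are reset at every stage start; outside fit(), only a stage that has
    # not been seen before triggers a reset.
    return fn == TrainerFn.FITTING or stage not in seen_stages


# Example: a validation run inside fit() resets even though the stage was seen before.
assert reset_needed(TrainerFn.FITTING, RunningStage.VALIDATING, {RunningStage.VALIDATING})
```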
@@ -202,10 +204,17 @@ def on_validation_batch_end(
     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
             return
+
+        train_times = self._throughputs[RunningStage.TRAINING]._time
+        val_times = self._throughputs[RunningStage.VALIDATING]._time
+
+        train_elapsed = train_times[-1] if train_times else 0.0
+        val_elapsed = val_times[-1] if val_times else 0.0
+
         # add the validation time to the training time before continuing to avoid sinking the training throughput
-        training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
+        training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed
         time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
-        val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
+        val_time = val_elapsed
         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
 
     @override
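A note on the second hunk (not part of the diff): each entry in `Throughput._time` appears to be the cumulative elapsed time at that measurement, which is what this fix relies on. Summing the deque therefore re-counts every earlier window, so the offset added to `self._t0s[RunningStage.TRAINING]` grows with the number of validation runs; that is the "sum overflow" the new test targets. The last entry alone is already the total elapsed time. A tiny illustration with made-up numbers:

```python
# Hypothetical cumulative elapsed-time samples, one per measurement window.
times = [10.0, 20.0, 30.0, 40.0]

total_elapsed = times[-1] if times else 0.0  # 40.0: what the patched code uses
overcounted = sum(times)                     # 100.0: keeps inflating as windows accumulate
```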
23 changes: 23 additions & 0 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -420,3 +420,26 @@ def variable_batch_size_fn(batch):
             train_samples.append(metrics["train/samples"])
         elif "train|samples" in metrics:
             train_samples.append(metrics["train|samples"])
+
+
+def test_throughput_monitor_validation_sum_overflow_real(tmp_path):
Collaborator commented on this line:

> is there a way for us to check that things work as expected? maybe by mocking the timings to make sure the throughput is as expected

Author replied:

> @SkafteNicki I edited the test so that it also checks the times, even when running a large number of epochs.

(A sketch of the suggested timing-mock approach follows the test diff below.)
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+    monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
+    model = BoringModel()
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=[monitor],
+        max_epochs=100,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    try:
+        trainer.fit(model)
+    except Exception as e:
+        pytest.fail(f"ThroughputMonitor raised an unexpected exception: {e}")