File tree Expand file tree Collapse file tree 2 files changed +35
-3
lines changed
src/lightning/pytorch/callbacks
tests/tests_pytorch/callbacks Expand file tree Collapse file tree 2 files changed +35
-3
lines changed Original file line number Diff line number Diff line change @@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
109109 stage = trainer .state .stage
110110 assert stage is not None
111111
112- if stage not in self ._samples :
112+ reset_needed = trainer .state .fn == TrainerFn .FITTING or stage not in self ._samples
113+
114+ if reset_needed :
113115 self ._throughputs [stage ].reset ()
114116 self ._lengths [stage ] = 0
115117 self ._samples [stage ] = 0
@@ -202,10 +204,17 @@ def on_validation_batch_end(
202204 def on_validation_end (self , trainer : "Trainer" , * _ : Any ) -> None :
203205 if trainer .sanity_checking or trainer .state .fn != TrainerFn .FITTING :
204206 return
207+
208+ train_times = self ._throughputs [RunningStage .TRAINING ]._time
209+ val_times = self ._throughputs [RunningStage .VALIDATING ]._time
210+
211+ train_elapsed = train_times [- 1 ] if train_times else 0.0
212+ val_elapsed = val_times [- 1 ] if val_times else 0.0
213+
205214 # add the validation time to the training time before continuing to avoid sinking the training throughput
206- training_finished = self ._t0s [RunningStage .TRAINING ] + sum ( self . _throughputs [ RunningStage . TRAINING ]. _time )
215+ training_finished = self ._t0s [RunningStage .TRAINING ] + train_elapsed
207216 time_between_train_and_val = self ._t0s [RunningStage .VALIDATING ] - training_finished
208- val_time = sum ( self . _throughputs [ RunningStage . VALIDATING ]. _time )
217+ val_time = val_elapsed
209218 self ._t0s [RunningStage .TRAINING ] += time_between_train_and_val + val_time
210219
211220 @override
Original file line number Diff line number Diff line change @@ -420,3 +420,26 @@ def variable_batch_size_fn(batch):
420420 train_samples .append (metrics ["train/samples" ])
421421 elif "train|samples" in metrics :
422422 train_samples .append (metrics ["train|samples" ])
423+
424+
425+ def test_throughput_monitor_validation_sum_overflow_real (tmp_path ):
426+ logger_mock = Mock ()
427+ logger_mock .save_dir = tmp_path
428+ monitor = ThroughputMonitor (batch_size_fn = lambda x : 1 )
429+ model = BoringModel ()
430+ model .flops_per_batch = 10
431+
432+ trainer = Trainer (
433+ devices = 1 ,
434+ logger = logger_mock ,
435+ callbacks = [monitor ],
436+ max_epochs = 100 ,
437+ enable_checkpointing = False ,
438+ enable_model_summary = False ,
439+ enable_progress_bar = False ,
440+ )
441+
442+ try :
443+ trainer .fit (model )
444+ except Exception as e :
445+ pytest .fail (f"ThroughputMonitor raised an unexpected exception: { e } " )
You can’t perform that action at this time.
0 commit comments