 else:
     from tqdm import tqdm as _tqdm
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress.base import ProgressBarBase
 
 _PAD_SIZE = 5
@@ -206,12 +207,10 @@ def init_test_tqdm(self) -> Tqdm:
         return bar
 
     def on_sanity_check_start(self, trainer, pl_module):
-        super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
         self.main_progress_bar = Tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, trainer, pl_module):
-        super().on_sanity_check_end(trainer, pl_module)
         self.main_progress_bar.close()
         self.val_progress_bar.close()
 
@@ -233,49 +232,59 @@ def on_train_epoch_start(self, trainer, pl_module):
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
-        total_batches = self.total_train_batches + self.total_val_batches
-        total_batches = convert_inf(total_batches)
-        if self._should_update(self.train_batch_idx, total_batches):
+        if self._should_update(self.train_batch_idx):
             self._update_bar(self.main_progress_bar)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
+            self._update_bar(self.main_progress_bar)
+            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.main_progress_bar.close()
+
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if trainer.sanity_checking:
             reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
         else:
-            self._update_bar(self.main_progress_bar)  # fill up remaining
+            if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
+                self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
             reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
+        if self._should_update(self.val_batch_idx):
+            self._update_bar(self.val_progress_bar)
+            if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
+                self._update_bar(self.main_progress_bar)
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
             self._update_bar(self.val_progress_bar)
-            self._update_bar(self.main_progress_bar)
 
     def on_validation_end(self, trainer, pl_module):
-        super().on_validation_end(trainer, pl_module)
-        if self.main_progress_bar is not None:
+        if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
         self.val_progress_bar.close()
 
-    def on_train_end(self, trainer, pl_module):
-        super().on_train_end(trainer, pl_module)
-        self.main_progress_bar.close()
-
     def on_test_start(self, trainer, pl_module):
         super().on_test_start(trainer, pl_module)
         self.test_progress_bar = self.init_test_tqdm()
         self.test_progress_bar.total = convert_inf(self.total_test_batches)
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.test_batch_idx, self.total_test_batches):
+        if self._should_update(self.test_batch_idx):
+            self._update_bar(self.test_progress_bar)
+
+    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self.is_enabled:
             self._update_bar(self.test_progress_bar)
 
     def on_test_end(self, trainer, pl_module):
-        super().on_test_end(trainer, pl_module)
         self.test_progress_bar.close()
 
     def on_predict_epoch_start(self, trainer, pl_module):
@@ -285,7 +294,7 @@ def on_predict_epoch_start(self, trainer, pl_module):
 
     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
+        if self._should_update(self.predict_batch_idx):
             self._update_bar(self.predict_progress_bar)
 
     def on_predict_end(self, trainer, pl_module):
@@ -309,8 +318,8 @@ def print(
             s = sep.join(map(str, args))
             active_progress_bar.write(s, end=end, file=file, nolock=nolock)
 
-    def _should_update(self, current, total) -> bool:
-        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
+    def _should_update(self, idx: int) -> bool:
+        return self.is_enabled and (idx % self.refresh_rate == 0)
 
     def _update_bar(self, bar: Optional[Tqdm]) -> None:
         """Updates the bar by the refresh rate without overshooting."""
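
For context on the `_should_update` change, here is a minimal, self-contained sketch (not the library code) of the update pattern the diff moves to: the bar advances only every `refresh_rate` batches, and the newly added `on_*_epoch_end` hooks are what top the bar up to its total, which is why the old `current == total` term can be dropped. The `refresh_rate` and `total_batches` values are hypothetical.

from tqdm import tqdm

refresh_rate = 20    # hypothetical refresh setting
total_batches = 137  # hypothetical epoch length, not a multiple of refresh_rate

bar = tqdm(total=total_batches)
for idx in range(1, total_batches + 1):
    # new-style _should_update: modulo check only, no `current == total` term
    if idx % refresh_rate == 0:
        bar.update(refresh_rate)
# the added on_*_epoch_end hooks play this role: fill whatever the modulo check skipped
bar.update(total_batches - bar.n)
bar.close()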