@@ -254,12 +254,13 @@ def on_train_start(self, *_: Any) -> None:
254254 def on_train_epoch_start (self , trainer : "pl.Trainer" , * _ : Any ) -> None :
255255 total_batches = self .total_batches_current_epoch
256256 self .main_progress_bar .reset (convert_inf (total_batches ))
257+ self .main_progress_bar .initial = 0
257258 self .main_progress_bar .set_description (f"Epoch { trainer .current_epoch } " )
258259
259260 def on_train_batch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , * _ : Any ) -> None :
260261 current = self .train_batch_idx + self ._val_processed
261262 if self ._should_update (current , self .main_progress_bar .total ):
262- _update_n (self .main_progress_bar , current , self . refresh_rate )
263+ _update_n (self .main_progress_bar , current )
263264 self .main_progress_bar .set_postfix (self .get_metrics (trainer , pl_module ))
264265
265266 def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
@@ -280,16 +281,17 @@ def on_validation_batch_start(
280281 return
281282
282283 self .val_progress_bar .reset (convert_inf (self .total_val_batches_current_dataloader ))
284+ self .val_progress_bar .initial = 0
283285 desc = self .sanity_check_description if trainer .sanity_checking else self .validation_description
284286 self .val_progress_bar .set_description (f"{ desc } DataLoader { dataloader_idx } " )
285287
286288 def on_validation_batch_end (self , trainer : "pl.Trainer" , * _ : Any ) -> None :
287289 if self ._should_update (self .val_batch_idx , self .val_progress_bar .total ):
288- _update_n (self .val_progress_bar , self .val_batch_idx , self . refresh_rate )
290+ _update_n (self .val_progress_bar , self .val_batch_idx )
289291
290292 current = self .train_batch_idx + self ._val_processed
291293 if trainer .state .fn == "fit" and self ._should_update (current , self .main_progress_bar .total ):
292- _update_n (self .main_progress_bar , current , self . refresh_rate )
294+ _update_n (self .main_progress_bar , current )
293295
294296 def on_validation_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
295297 if self ._main_progress_bar is not None and trainer .state .fn == "fit" :
@@ -307,11 +309,12 @@ def on_test_batch_start(
307309 return
308310
309311 self .test_progress_bar .reset (convert_inf (self .total_test_batches_current_dataloader ))
312+ self .test_progress_bar .initial = 0
310313 self .test_progress_bar .set_description (f"{ self .test_description } DataLoader { dataloader_idx } " )
311314
312315 def on_test_batch_end (self , * _ : Any ) -> None :
313316 if self ._should_update (self .test_batch_idx , self .test_progress_bar .total ):
314- _update_n (self .test_progress_bar , self .test_batch_idx , self . refresh_rate )
317+ _update_n (self .test_progress_bar , self .test_batch_idx )
315318
316319 def on_test_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
317320 self .test_progress_bar .close ()
@@ -327,11 +330,12 @@ def on_predict_batch_start(
327330 return
328331
329332 self .predict_progress_bar .reset (convert_inf (self .total_predict_batches_current_dataloader ))
333+ self .predict_progress_bar .initial = 0
330334 self .predict_progress_bar .set_description (f"{ self .predict_description } DataLoader { dataloader_idx } " )
331335
332336 def on_predict_batch_end (self , * _ : Any ) -> None :
333337 if self ._should_update (self .predict_batch_idx , self .predict_progress_bar .total ):
334- _update_n (self .predict_progress_bar , self .predict_batch_idx , self . refresh_rate )
338+ _update_n (self .predict_progress_bar , self .predict_batch_idx )
335339
336340 def on_predict_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
337341 self .predict_progress_bar .close ()
@@ -375,9 +379,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
375379 return x
376380
377381
378- def _update_n (bar : _tqdm , current : int , refresh_rate : int ) -> None :
382+ def _update_n (bar : _tqdm , value : int ) -> None :
379383 if not bar .disable :
380- total = bar .total
381- leftover = current % refresh_rate
382- advance = leftover if (current == total and leftover != 0 ) else refresh_rate
383- bar .update (advance )
384+ bar .n = value
385+ bar .refresh ()
0 commit comments