@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
171
171
return Text ()
172
172
if self ._trainer .training and task .id not in self ._tasks :
173
173
self ._tasks [task .id ] = "None"
174
- if self ._renderable_cache :
174
+ if self ._renderable_cache and self . _current_task_id in self . _renderable_cache :
175
175
self ._current_task_id = cast (TaskID , self ._current_task_id )
176
176
self ._tasks [self ._current_task_id ] = self ._renderable_cache [self ._current_task_id ][1 ]
177
177
self ._current_task_id = task .id
@@ -185,7 +185,10 @@ def render(self, task: "Task") -> Text:
185
185
def _generate_metrics_texts (self ) -> Generator [str , None , None ]:
186
186
for name , value in self ._metrics .items ():
187
187
if not isinstance (value , str ):
188
- value = f"{ value :{self ._metrics_format }} "
188
+ try :
189
+ value = f"{ value :{self ._metrics_format }} "
190
+ except (TypeError , ValueError ):
191
+ value = str (value )
189
192
yield f"{ name } : { value } "
190
193
191
194
@@ -465,17 +468,12 @@ def _initialize_train_progress_bar_id(self) -> None:
465
468
self .train_progress_bar_id = self ._add_task (total_batches , train_description )
466
469
467
470
def _update (self , progress_bar_id : Optional ["TaskID" ], current : int , visible : bool = True ) -> None :
468
- if self .progress is not None and self .is_enabled :
469
- assert progress_bar_id is not None
471
+ if self .progress is not None and self .is_enabled and progress_bar_id is not None :
470
472
total = self .progress .tasks [progress_bar_id ].total
471
473
assert total is not None
472
474
if not self ._should_update (current , total ):
473
475
return
474
-
475
- leftover = current % self .refresh_rate
476
- advance = leftover if (current == total and leftover != 0 ) else self .refresh_rate
477
- self .progress .update (progress_bar_id , advance = advance , visible = visible )
478
- self .refresh ()
476
+ self .progress .update (progress_bar_id , completed = current , visible = visible )
479
477
480
478
def _should_update (self , current : int , total : Union [int , float ]) -> bool :
481
479
return current % self .refresh_rate == 0 or current == total
@@ -572,9 +570,13 @@ def on_validation_batch_end(
572
570
if self .is_disabled :
573
571
return
574
572
if trainer .sanity_checking :
575
- self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 )
576
- elif self .val_progress_bar_id is not None :
577
- self ._update (self .val_progress_bar_id , batch_idx + 1 )
573
+ if self .val_sanity_progress_bar_id is not None :
574
+ self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 )
575
+ return
576
+
577
+ if self .val_progress_bar_id is None :
578
+ return
579
+ self ._update (self .val_progress_bar_id , batch_idx + 1 )
578
580
self .refresh ()
579
581
580
582
@override
@@ -587,9 +589,8 @@ def on_test_batch_end(
587
589
batch_idx : int ,
588
590
dataloader_idx : int = 0 ,
589
591
) -> None :
590
- if self .is_disabled :
592
+ if self .is_disabled or self . test_progress_bar_id is None :
591
593
return
592
- assert self .test_progress_bar_id is not None
593
594
self ._update (self .test_progress_bar_id , batch_idx + 1 )
594
595
self .refresh ()
595
596
@@ -603,9 +604,8 @@ def on_predict_batch_end(
603
604
batch_idx : int ,
604
605
dataloader_idx : int = 0 ,
605
606
) -> None :
606
- if self .is_disabled :
607
+ if self .is_disabled or self . predict_progress_bar_id is None :
607
608
return
608
- assert self .predict_progress_bar_id is not None
609
609
self ._update (self .predict_progress_bar_id , batch_idx + 1 )
610
610
self .refresh ()
611
611
0 commit comments