@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
171
171
return Text ()
172
172
if self ._trainer .training and task .id not in self ._tasks :
173
173
self ._tasks [task .id ] = "None"
174
- if self ._renderable_cache :
174
+ if self ._renderable_cache and self . _current_task_id in self . _renderable_cache :
175
175
self ._current_task_id = cast (TaskID , self ._current_task_id )
176
176
self ._tasks [self ._current_task_id ] = self ._renderable_cache [self ._current_task_id ][1 ]
177
177
self ._current_task_id = task .id
@@ -185,7 +185,10 @@ def render(self, task: "Task") -> Text:
185
185
def _generate_metrics_texts (self ) -> Generator [str , None , None ]:
186
186
for name , value in self ._metrics .items ():
187
187
if not isinstance (value , str ):
188
- value = f"{ value :{self ._metrics_format }} "
188
+ try :
189
+ value = f"{ value :{self ._metrics_format }} "
190
+ except (TypeError , ValueError ):
191
+ value = str (value )
189
192
yield f"{ name } : { value } "
190
193
191
194
@@ -448,17 +451,12 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
448
451
)
449
452
450
453
def _update (self , progress_bar_id : Optional ["TaskID" ], current : int , visible : bool = True ) -> None :
451
- if self .progress is not None and self .is_enabled :
452
- assert progress_bar_id is not None
454
+ if self .progress is not None and self .is_enabled and progress_bar_id is not None :
453
455
total = self .progress .tasks [progress_bar_id ].total
454
456
assert total is not None
455
457
if not self ._should_update (current , total ):
456
458
return
457
-
458
- leftover = current % self .refresh_rate
459
- advance = leftover if (current == total and leftover != 0 ) else self .refresh_rate
460
- self .progress .update (progress_bar_id , advance = advance , visible = visible )
461
- self .refresh ()
459
+ self .progress .update (progress_bar_id , completed = current , visible = visible )
462
460
463
461
def _should_update (self , current : int , total : Union [int , float ]) -> bool :
464
462
return current % self .refresh_rate == 0 or current == total
@@ -552,9 +550,13 @@ def on_validation_batch_end(
552
550
if self .is_disabled :
553
551
return
554
552
if trainer .sanity_checking :
555
- self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 )
556
- elif self .val_progress_bar_id is not None :
557
- self ._update (self .val_progress_bar_id , batch_idx + 1 )
553
+ if self .val_sanity_progress_bar_id is not None :
554
+ self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 )
555
+ return
556
+
557
+ if self .val_progress_bar_id is None :
558
+ return
559
+ self ._update (self .val_progress_bar_id , batch_idx + 1 )
558
560
self .refresh ()
559
561
560
562
@override
@@ -567,9 +569,8 @@ def on_test_batch_end(
567
569
batch_idx : int ,
568
570
dataloader_idx : int = 0 ,
569
571
) -> None :
570
- if self .is_disabled :
572
+ if self .is_disabled or self . test_progress_bar_id is None :
571
573
return
572
- assert self .test_progress_bar_id is not None
573
574
self ._update (self .test_progress_bar_id , batch_idx + 1 )
574
575
self .refresh ()
575
576
@@ -583,9 +584,8 @@ def on_predict_batch_end(
583
584
batch_idx : int ,
584
585
dataloader_idx : int = 0 ,
585
586
) -> None :
586
- if self .is_disabled :
587
+ if self .is_disabled or self . predict_progress_bar_id is None :
587
588
return
588
- assert self .predict_progress_bar_id is not None
589
589
self ._update (self .predict_progress_bar_id , batch_idx + 1 )
590
590
self .refresh ()
591
591
0 commit comments