@@ -129,11 +129,12 @@ def render(self, task) -> RenderableType:
129129 class MetricsTextColumn (ProgressColumn ):
130130 """A column containing text."""
131131
132- def __init__ (self , trainer ):
132+ def __init__ (self , trainer , style ):
133133 self ._trainer = trainer
134134 self ._tasks = {}
135135 self ._current_task_id = 0
136136 self ._metrics = {}
137+ self ._style = style
137138 super ().__init__ ()
138139
139140 def update (self , metrics ):
@@ -158,23 +159,34 @@ def render(self, task) -> Text:
158159
159160 for k , v in self ._metrics .items ():
160161 _text += f"{ k } : { round (v , 3 ) if isinstance (v , float ) else v } "
161- return Text (_text , justify = "left" )
162+ return Text (_text , justify = "left" , style = self . _style )
162163
163164
164165@dataclass
165166class RichProgressBarTheme :
166167 """Styles to associate to different base components.
167168
169+ Args:
170+ description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
171+ progress_bar: Style for the bar in progress.
172+ progress_bar_finished: Style for the finished progress bar.
173+ progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
174+ batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
175+ time: Style for the processed time and estimate time remaining.
176+ processing_speed: Style for the speed of the batches being processed.
177+ metrics: Style for the metrics
178+
168179 https://rich.readthedocs.io/en/stable/style.html
169180 """
170181
171- text_color : str = "white"
172- progress_bar_complete : Union [str , Style ] = "#6206E0"
182+ description : Union [ str , Style ] = "white"
183+ progress_bar : Union [str , Style ] = "#6206E0"
173184 progress_bar_finished : Union [str , Style ] = "#6206E0"
174185 progress_bar_pulse : Union [str , Style ] = "#6206E0"
175- batch_process : str = "white"
176- time : str = "grey54"
177- processing_speed : str = "grey70"
186+ batch_progress : Union [str , Style ] = "white"
187+ time : Union [str , Style ] = "grey54"
188+ processing_speed : Union [str , Style ] = "grey70"
189+ metrics : Union [str , Style ] = "white"
178190
179191
180192class RichProgressBar (ProgressBarBase ):
@@ -268,7 +280,7 @@ def _init_progress(self, trainer):
268280 self ._reset_progress_bar_ids ()
269281 self ._console : Console = Console ()
270282 self ._console .clear_live ()
271- self ._metric_component = MetricsTextColumn (trainer )
283+ self ._metric_component = MetricsTextColumn (trainer , self . theme . metrics )
272284 self .progress = CustomProgress (
273285 * self .configure_columns (trainer ),
274286 self ._metric_component ,
@@ -351,7 +363,7 @@ def on_validation_epoch_start(self, trainer, pl_module):
351363 def _add_task (self , total_batches : int , description : str , visible : bool = True ) -> Optional [int ]:
352364 if self .progress is not None :
353365 return self .progress .add_task (
354- f"[{ self .theme .text_color } ]{ description } " , total = total_batches , visible = visible
366+ f"[{ self .theme .description } ]{ description } " , total = total_batches , visible = visible
355367 )
356368
357369 def _update (self , progress_bar_id : int , visible : bool = True ) -> None :
@@ -448,11 +460,11 @@ def configure_columns(self, trainer) -> list:
448460 return [
449461 TextColumn ("[progress.description]{task.description}" ),
450462 CustomBarColumn (
451- complete_style = self .theme .progress_bar_complete ,
463+ complete_style = self .theme .progress_bar ,
452464 finished_style = self .theme .progress_bar_finished ,
453465 pulse_style = self .theme .progress_bar_pulse ,
454466 ),
455- BatchesProcessedColumn (style = self .theme .batch_process ),
467+ BatchesProcessedColumn (style = self .theme .batch_progress ),
456468 CustomTimeColumn (style = self .theme .time ),
457469 ProcessingSpeedColumn (style = self .theme .processing_speed ),
458470 ]
0 commit comments