diff --git a/requirements/typing.txt b/requirements/typing.txt index e8a2baaf8713a..ad243944198ec 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,6 @@ mypy==1.8.0 torch==2.2.0 +colored types-Markdown types-PyYAML diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 4fcca7d0b0e65..fddb3469c3c8f 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -16,6 +16,7 @@ from datetime import timedelta from typing import Any, Dict, Generator, Optional, Union, cast +import colored from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -221,6 +222,15 @@ class RichProgressBarTheme: metrics_format: str = ".3f" +def detect_color_theme(): + """Detect the color theme of the terminal.""" + if colored.supports_color(): + if colored.detect_color() == "truecolor": + return "dark" + return "light" + return "unknown" + + class RichProgressBar(ProgressBar): """Create a progress bar with `rich text formatting `_. @@ -284,6 +294,8 @@ def __init__( self._progress_stopped: bool = False self.theme = theme self._update_for_light_colab_theme() + self._color_theme = detect_color_theme() + self._update_for_light_colab_theme() @property def refresh_rate(self) -> float: @@ -641,16 +653,30 @@ def on_exception( self._stop_progress() def configure_columns(self, trainer: "pl.Trainer") -> list: + # Modify the color of progress bar based on the detected color theme + if self._color_theme == "dark": + theme = RichProgressBarTheme( + progress_bar="green", + progress_bar_finished="green", + progress_bar_pulse="green", + ) + else: + theme = RichProgressBarTheme( + progress_bar="blue", + progress_bar_finished="blue", + progress_bar_pulse="blue", + ) + return [ TextColumn("[progress.description]{task.description}"), CustomBarColumn( - complete_style=self.theme.progress_bar, - finished_style=self.theme.progress_bar_finished, - pulse_style=self.theme.progress_bar_pulse, + complete_style=theme.progress_bar, + finished_style=theme.progress_bar_finished, + pulse_style=theme.progress_bar_pulse, ), - BatchesProcessedColumn(style=self.theme.batch_progress), - CustomTimeColumn(style=self.theme.time), - ProcessingSpeedColumn(style=self.theme.processing_speed), + BatchesProcessedColumn(style=theme.batch_progress), + CustomTimeColumn(style=theme.time), + ProcessingSpeedColumn(style=theme.processing_speed), ] def __getstate__(self) -> Dict: