Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mypy==1.5.1
torch==2.1.0
colored
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@manascb1344 We can't add a dependency to Lightning like this. The rich package is optional for Lightning. Is there a different way to detect the color?

The progress bar will now adjust its color depending on whether the terminal supports color and if it's a truecolor terminal.

Could you check whether this feature is already supported in rich itself? It feels like this is probably something that the package already handles, we just haven't configured it properly. I'm not familiar with this unfortunately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, isn't purple suitable both for light and dark?


types-Markdown
types-PyYAML
Expand Down
40 changes: 34 additions & 6 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import timedelta
from typing import Any, Dict, Generator, Optional, Union, cast

import colored
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

Expand Down Expand Up @@ -221,6 +222,17 @@ class RichProgressBarTheme:
metrics_format: str = ".3f"


def detect_color_theme():
"""Detect the color theme of the terminal."""
if colored.supports_color():
if colored.detect_color() == "truecolor":
return "dark"
else:
return "light"
else:
return "unknown"


class RichProgressBar(ProgressBar):
"""Create a progress bar with `rich text formatting <https://github.com/Textualize/rich>`_.

Expand Down Expand Up @@ -284,6 +296,8 @@ def __init__(
self._progress_stopped: bool = False
self.theme = theme
self._update_for_light_colab_theme()
self._color_theme = detect_color_theme()
self._update_for_light_colab_theme()

@property
def refresh_rate(self) -> float:
Expand Down Expand Up @@ -641,16 +655,30 @@ def on_exception(
self._stop_progress()

def configure_columns(self, trainer: "pl.Trainer") -> list:
# Modify the color of progress bar based on the detected color theme
if self._color_theme == "dark":
theme = RichProgressBarTheme(
progress_bar="green",
progress_bar_finished="green",
progress_bar_pulse="green",
)
else:
theme = RichProgressBarTheme(
progress_bar="blue",
progress_bar_finished="blue",
progress_bar_pulse="blue",
)

return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
complete_style=self.theme.progress_bar,
finished_style=self.theme.progress_bar_finished,
pulse_style=self.theme.progress_bar_pulse,
complete_style=theme.progress_bar,
finished_style=theme.progress_bar_finished,
pulse_style=theme.progress_bar_pulse,
),
BatchesProcessedColumn(style=self.theme.batch_progress),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
BatchesProcessedColumn(style=theme.batch_progress),
CustomTimeColumn(style=theme.time),
ProcessingSpeedColumn(style=theme.processing_speed),
]

def __getstate__(self) -> Dict:
Expand Down