Skip to content

Commit 3bf6118

Browse files
tshu-wBorda
authored andcommitted
Make RichProgressBar visible for both light and dark background (#20260)
(cherry picked from commit 474bdd0)
1 parent 6cc2a72 commit 3bf6118

File tree

2 files changed

+6
-45
lines changed

2 files changed

+6
-45
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,14 @@ class RichProgressBarTheme:
206206
207207
"""
208208

209-
description: Union[str, "Style"] = "white"
209+
description: Union[str, "Style"] = ""
210210
progress_bar: Union[str, "Style"] = "#6206E0"
211211
progress_bar_finished: Union[str, "Style"] = "#6206E0"
212212
progress_bar_pulse: Union[str, "Style"] = "#6206E0"
213-
batch_progress: Union[str, "Style"] = "white"
214-
time: Union[str, "Style"] = "grey54"
215-
processing_speed: Union[str, "Style"] = "grey70"
216-
metrics: Union[str, "Style"] = "white"
213+
batch_progress: Union[str, "Style"] = ""
214+
time: Union[str, "Style"] = "dim"
215+
processing_speed: Union[str, "Style"] = "dim underline"
216+
metrics: Union[str, "Style"] = "italic"
217217
metrics_text_delimiter: str = " "
218218
metrics_format: str = ".3f"
219219

@@ -280,7 +280,6 @@ def __init__(
280280
self._metric_component: Optional[MetricsTextColumn] = None
281281
self._progress_stopped: bool = False
282282
self.theme = theme
283-
self._update_for_light_colab_theme()
284283

285284
@property
286285
def refresh_rate(self) -> float:
@@ -318,13 +317,6 @@ def test_progress_bar(self) -> "Task":
318317
assert self.test_progress_bar_id is not None
319318
return self.progress.tasks[self.test_progress_bar_id]
320319

321-
def _update_for_light_colab_theme(self) -> None:
322-
if _detect_light_colab_theme():
323-
attributes = ["description", "batch_progress", "metrics"]
324-
for attr in attributes:
325-
if getattr(self.theme, attr) == "white":
326-
setattr(self.theme, attr, "black")
327-
328320
@override
329321
def disable(self) -> None:
330322
self._enabled = False
@@ -449,7 +441,7 @@ def on_validation_batch_start(
449441
def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID":
450442
assert self.progress is not None
451443
return self.progress.add_task(
452-
f"[{self.theme.description}]{description}",
444+
f"[{self.theme.description}]{description}" if self.theme.description else description,
453445
total=total_batches,
454446
visible=visible,
455447
)
@@ -656,20 +648,3 @@ def __getstate__(self) -> Dict:
656648
state["progress"] = None
657649
state["_console"] = None
658650
return state
659-
660-
661-
def _detect_light_colab_theme() -> bool:
662-
"""Detect if it's light theme in Colab."""
663-
try:
664-
import get_ipython
665-
except (NameError, ModuleNotFoundError):
666-
return False
667-
ipython = get_ipython()
668-
if "google.colab" in str(ipython.__class__):
669-
try:
670-
from google.colab import output
671-
672-
return output.eval_js('document.documentElement.matches("[theme=light]")')
673-
except ModuleNotFoundError:
674-
return False
675-
return False

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,20 +308,6 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmp_path):
308308
assert val_bar.total == 4
309309

310310

311-
@RunIf(rich=True)
312-
@mock.patch("lightning.pytorch.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
313-
def test_rich_progress_bar_colab_light_theme_update(*_):
314-
theme = RichProgressBar().theme
315-
assert theme.description == "black"
316-
assert theme.batch_progress == "black"
317-
assert theme.metrics == "black"
318-
319-
theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme
320-
assert theme.description == "blue"
321-
assert theme.batch_progress == "black"
322-
assert theme.metrics == "red"
323-
324-
325311
@RunIf(rich=True)
326312
def test_rich_progress_bar_metric_display_task_id(tmp_path):
327313
class CustomModel(BoringModel):

0 commit comments

Comments
 (0)