
Commit 2ce0191

feat: Default to RichProgressBar and RichModelSummary if rich is available
Implements automatic detection of the 'rich' package and enables RichProgressBar and RichModelSummary by default in the Trainer when the package is present. This enhances the user experience with improved visual feedback without requiring manual configuration. Includes comprehensive tests for various scenarios. Fixes #9580
1 parent 64b2b6a commit 2ce0191
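
The user-visible effect is easiest to see on a bare Trainer. Below is a minimal sketch of the new default behavior, using only existing public API (Trainer, its progress_bar_callback property, and the two progress bar callbacks):

import lightning.pytorch as pl
from lightning.pytorch.callbacks import RichProgressBar, TQDMProgressBar

# No callbacks passed: the Trainer fills in its own defaults.
trainer = pl.Trainer()

# With this commit, the default progress bar is Rich-based whenever the
# `rich` package is importable, and falls back to the classic tqdm bar otherwise.
assert isinstance(trainer.progress_bar_callback, (RichProgressBar, TQDMProgressBar))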

File tree

3 files changed: +139 −8 lines changed


src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 3 additions & 8 deletions
@@ -37,6 +37,7 @@
 from lightning.pytorch.callbacks.timer import Timer
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
@@ -125,14 +126,8 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             )
             return
 
-        progress_bar_callback = self.trainer.progress_bar_callback
-        is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
-
         model_summary: ModelSummary
-        if progress_bar_callback is not None and is_progress_bar_rich:
-            model_summary = RichModelSummary()
-        else:
-            model_summary = ModelSummary()
+        model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
         self.trainer.callbacks.append(model_summary)
 
     def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
@@ -157,7 +152,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
             )
 
         if enable_progress_bar:
-            progress_bar_callback = TQDMProgressBar()
+            progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
             self.trainer.callbacks.append(progress_bar_callback)
 
     def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None:
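
Because _configure_model_summary_callback and _configure_progress_bar only append a default when no matching callback was configured by the user (the tests below exercise exactly this), the previous defaults stay one line away. A short sketch of opting back into tqdm and the plain summary, using standard Trainer arguments only:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary, TQDMProgressBar

# Explicitly provided callbacks take precedence over the rich-based defaults,
# so this Trainer keeps the tqdm bar and the plain model summary even when
# `rich` is installed.
trainer = Trainer(callbacks=[TQDMProgressBar(), ModelSummary()])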

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
+_RICH_AVAILABLE = RequirementCache("rich")
 
 
 @functools.lru_cache(maxsize=128)
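
_RICH_AVAILABLE follows the pattern of _TORCHVISION_AVAILABLE directly above it: RequirementCache (provided by the lightning-utilities package) resolves the requirement lazily, caches the result, and supports plain boolean evaluation. A rough sketch of how the flag behaves, not the actual implementation, assuming `rich` is installed:

from lightning_utilities.core.imports import RequirementCache

_RICH_AVAILABLE = RequirementCache("rich")

if _RICH_AVAILABLE:  # truthy only when the `rich` distribution is available
    from rich.console import Console
    Console().print("rich is available")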
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import ModelSummary, ProgressBar, RichModelSummary, RichProgressBar, TQDMProgressBar
+
+
+class TestRichIntegration:
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_no_rich_defaults_tqdm_and_model_summary(self, tmp_path):
+        trainer = Trainer(default_root_dir=tmp_path, logger=False, enable_checkpointing=False)
+        assert any(isinstance(cb, TQDMProgressBar) for cb in trainer.callbacks)
+        assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
+        assert not any(isinstance(cb, RichProgressBar) for cb in trainer.callbacks)
+        assert not any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_no_rich_respects_user_provided_tqdm_progress_bar(self, tmp_path):
+        user_progress_bar = TQDMProgressBar()
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False
+        )
+        assert user_progress_bar in trainer.callbacks
+        assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_no_rich_respects_user_provided_rich_progress_bar(self, tmp_path):
+        # If user explicitly provides RichProgressBar, it should be used,
+        # even if _RICH_AVAILABLE is False (simulating our connector logic).
+        # RequirementCache would normally prevent RichProgressBar instantiation if rich is truly not installed.
+        user_progress_bar = RichProgressBar()
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False
+        )
+        assert user_progress_bar in trainer.callbacks
+        assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1
+        assert isinstance(trainer.progress_bar_callback, RichProgressBar)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_no_rich_respects_user_provided_model_summary(self, tmp_path):
+        user_model_summary = ModelSummary()
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False
+        )
+        assert user_model_summary in trainer.callbacks
+        assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1
+        # Check that the specific instance is the one from the trainer's list of ModelSummary callbacks
+        assert trainer.callbacks[trainer.callbacks.index(user_model_summary)] == user_model_summary
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_no_rich_respects_user_provided_rich_model_summary(self, tmp_path):
+        user_model_summary = RichModelSummary()
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False
+        )
+        assert user_model_summary in trainer.callbacks
+        assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1
+        # Check that the specific instance is the one from the trainer's list of ModelSummary callbacks
+        model_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)]
+        assert user_model_summary in model_summary_callbacks
+        assert isinstance(model_summary_callbacks[0], RichModelSummary)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
+    def test_rich_available_defaults_rich_progress_and_summary(self, tmp_path):
+        trainer = Trainer(default_root_dir=tmp_path, logger=False, enable_checkpointing=False)
+        assert any(isinstance(cb, RichProgressBar) for cb in trainer.callbacks)
+        assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks)
+        assert not any(isinstance(cb, TQDMProgressBar) for cb in trainer.callbacks)
+        # Ensure the only ModelSummary is the RichModelSummary
+        model_summaries = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)]
+        assert len(model_summaries) == 1
+        assert isinstance(model_summaries[0], RichModelSummary)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
+    def test_rich_available_respects_user_tqdm_progress_bar(self, tmp_path):
+        user_progress_bar = TQDMProgressBar()
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False
+        )
+        assert user_progress_bar in trainer.callbacks
+        assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1
+        assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
+    def test_rich_available_respects_user_model_summary(self, tmp_path):
+        user_model_summary = ModelSummary()  # Non-rich
+        trainer = Trainer(
+            default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False
+        )
+        assert user_model_summary in trainer.callbacks
+        model_summaries = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)]
+        assert len(model_summaries) == 1
+        assert isinstance(model_summaries[0], ModelSummary)
+        assert not isinstance(model_summaries[0], RichModelSummary)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_progress_bar_disabled_no_rich(self, tmp_path):
+        trainer = Trainer(
+            default_root_dir=tmp_path, enable_progress_bar=False, logger=False, enable_checkpointing=False
+        )
+        assert not any(isinstance(cb, ProgressBar) for cb in trainer.callbacks)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
+    def test_progress_bar_disabled_with_rich(self, tmp_path):
+        trainer = Trainer(
+            default_root_dir=tmp_path, enable_progress_bar=False, logger=False, enable_checkpointing=False
+        )
+        assert not any(isinstance(cb, ProgressBar) for cb in trainer.callbacks)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+    def test_model_summary_disabled_no_rich(self, tmp_path):
+        trainer = Trainer(
+            default_root_dir=tmp_path, enable_model_summary=False, logger=False, enable_checkpointing=False
+        )
+        assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
+
+    @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True)
+    def test_model_summary_disabled_with_rich(self, tmp_path):
+        trainer = Trainer(
+            default_root_dir=tmp_path, enable_model_summary=False, logger=False, enable_checkpointing=False
+        )
+        assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
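
One detail worth noting in these tests: @patch targets _RICH_AVAILABLE in lightning.pytorch.trainer.connectors.callback_connector, the module that uses the flag, rather than in lightning.pytorch.utilities.imports where it is defined. A from-import binds the name into the importing module, so that binding is the one that must be replaced. The same idea with pytest's monkeypatch fixture (a hypothetical test, not part of this commit):

import lightning.pytorch.trainer.connectors.callback_connector as callback_connector

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar


def test_rich_flag_forced_off(monkeypatch, tmp_path):
    # Replace the flag where the connector looks it up, mirroring @patch above.
    monkeypatch.setattr(callback_connector, "_RICH_AVAILABLE", False)
    trainer = Trainer(default_root_dir=tmp_path, logger=False, enable_checkpointing=False)
    assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)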
