
Commit 2f5c4b6

littlebullGit, bhimrazy, Borda, SkafteNicki, and pre-commit-ci[bot] authored
feat: Default to RichProgressBar and RichModelSummary if rich is available (#20896)
* feat: Default to RichProgressBar and RichModelSummary if rich is available

Implements automatic detection of the `rich` package and enables RichProgressBar and RichModelSummary by default in the Trainer when the package is present. This improves visual feedback out of the box, without requiring manual configuration. Includes comprehensive tests for the various scenarios.

---------

Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2676332 commit 2f5c4b6
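
The new default is observable straight from a fresh Trainer; a minimal sketch (assuming this commit, in an environment where `rich>=10.2.2` may or may not be installed):

# Minimal sketch: which progress bar the Trainer now attaches by default.
from lightning.pytorch import Trainer

trainer = Trainer(enable_progress_bar=True, enable_model_summary=True)
print(type(trainer.progress_bar_callback).__name__)
# "RichProgressBar" when rich>=10.2.2 is importable, "TQDMProgressBar" otherwise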

File tree

13 files changed: +224, −48 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Allow returning `ONNXProgram` when calling `to_onnx(dynamo=True)` ([#20811](https://github.com/Lightning-AI/pytorch-lightning/pull/20811))
 
 
+- Default to RichProgressBar and RichModelSummary if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
+
+
 ### Removed
 
 -

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 28 additions & 20 deletions

@@ -17,15 +17,15 @@
 from datetime import timedelta
 from typing import Any, Optional, Union, cast
 
-from lightning_utilities.core.imports import RequirementCache
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
-_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
-
 if _RICH_AVAILABLE:
     from rich import get_console, reconfigure
     from rich.console import Console, RenderableType
@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
             return Text()
         if self._trainer.training and task.id not in self._tasks:
             self._tasks[task.id] = "None"
-            if self._renderable_cache:
+            if self._renderable_cache and self._current_task_id in self._renderable_cache:
                 self._current_task_id = cast(TaskID, self._current_task_id)
                 self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
             self._current_task_id = task.id
@@ -184,8 +184,11 @@ def render(self, task: "Task") -> Text:
 
     def _generate_metrics_texts(self) -> Generator[str, None, None]:
         for name, value in self._metrics.items():
-            if not isinstance(value, (str, int)):
-                value = f"{value:{self._metrics_format}}"
+            if not isinstance(value, str):
+                try:
+                    value = f"{value:{self._metrics_format}}"
+                except (TypeError, ValueError):
+                    value = str(value)
             yield f"{name}: {value}"
 
 
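A standalone sketch of the fallback above (the `_format_metric` helper is hypothetical, mirroring the new try/except): any value that rejects the format spec now degrades to `str()` instead of raising.

def _format_metric(value, metrics_format=".3f"):
    # mirrors RichProgressBar._generate_metrics_texts: strings pass through,
    # numbers get the format spec, everything else falls back to str()
    if isinstance(value, str):
        return value
    try:
        return f"{value:{metrics_format}}"
    except (TypeError, ValueError):
        return str(value)

print(_format_metric(0.123456))  # -> 0.123
print(_format_metric(None))      # -> None (previously raised TypeError)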
@@ -465,17 +468,12 @@ def _initialize_train_progress_bar_id(self) -> None:
         self.train_progress_bar_id = self._add_task(total_batches, train_description)
 
     def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
-        if self.progress is not None and self.is_enabled:
-            assert progress_bar_id is not None
+        if self.progress is not None and self.is_enabled and progress_bar_id is not None:
             total = self.progress.tasks[progress_bar_id].total
             assert total is not None
             if not self._should_update(current, total):
                 return
-
-            leftover = current % self.refresh_rate
-            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
-            self.progress.update(progress_bar_id, advance=advance, visible=visible)
-            self.refresh()
+            self.progress.update(progress_bar_id, completed=current, visible=visible)
 
     def _should_update(self, current: int, total: Union[int, float]) -> bool:
         return current % self.refresh_rate == 0 or current == total
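
The incremental `advance=` bookkeeping is replaced by an absolute `completed=` position, which is idempotent under skipped or repeated refreshes; a small rich-only sketch of the pattern (assuming `rich` is installed):

from rich.progress import Progress

with Progress() as progress:
    task = progress.add_task("train", total=100)
    for current in range(10, 101, 10):
        # an absolute position cannot drift if an update is skipped or repeated,
        # unlike incremental advance= accounting
        progress.update(task, completed=current)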
@@ -572,9 +570,13 @@ def on_validation_batch_end(
         if self.is_disabled:
             return
         if trainer.sanity_checking:
-            self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
-        elif self.val_progress_bar_id is not None:
-            self._update(self.val_progress_bar_id, batch_idx + 1)
+            if self.val_sanity_progress_bar_id is not None:
+                self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
+            return
+
+        if self.val_progress_bar_id is None:
+            return
+        self._update(self.val_progress_bar_id, batch_idx + 1)
         self.refresh()
 
     @override
@@ -587,9 +589,8 @@ def on_test_batch_end(
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if self.is_disabled:
+        if self.is_disabled or self.test_progress_bar_id is None:
             return
-        assert self.test_progress_bar_id is not None
         self._update(self.test_progress_bar_id, batch_idx + 1)
         self.refresh()
 

@@ -603,9 +604,8 @@ def on_predict_batch_end(
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if self.is_disabled:
+        if self.is_disabled or self.predict_progress_bar_id is None:
             return
-        assert self.predict_progress_bar_id is not None
         self._update(self.predict_progress_bar_id, batch_idx + 1)
         self.refresh()

@@ -632,6 +632,14 @@ def _reset_progress_bar_ids(self) -> None:
         self.test_progress_bar_id = None
         self.predict_progress_bar_id = None
 
+    @override
+    def get_metrics(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
+        items = super().get_metrics(trainer, pl_module)
+        # convert all metrics to float before sending to rich
+        return apply_to_collection(items, torch.Tensor, lambda x: x.item())
+
     def _update_metrics(
         self,
         trainer: "pl.Trainer",
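
What the new `get_metrics` override does to a metrics dict, sketched in isolation (assuming `torch` and `lightning_utilities` are installed):

import torch
from lightning_utilities.core.apply_func import apply_to_collection

metrics = {"loss": torch.tensor(0.25), "v_num": 3, "stage": "train"}
# only Tensor leaves are converted; other values pass through unchanged
print(apply_to_collection(metrics, torch.Tensor, lambda x: x.item()))
# -> {'loss': 0.25, 'v_num': 3, 'stage': 'train'}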

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 4 additions & 4 deletions

@@ -274,7 +274,7 @@ def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.train_progress_bar.total):
+        if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total):
             _update_n(self.train_progress_bar, n)
             self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
@@ -322,7 +322,7 @@ def on_validation_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.val_progress_bar.total):
+        if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total):
             _update_n(self.val_progress_bar, n)
 
     @override
@@ -363,7 +363,7 @@ def on_test_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.test_progress_bar.total):
+        if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total):
             _update_n(self.test_progress_bar, n)
 
     @override
@@ -402,7 +402,7 @@ def on_predict_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.predict_progress_bar.total):
+        if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total):
             _update_n(self.predict_progress_bar, n)
 
     @override

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 from typing_extensions import override
 
 from lightning.pytorch.callbacks import ModelSummary
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_summary import get_human_readable_count
 
 

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,6 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _set_sampler_epoch
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.progress import _BatchProgress
@@ -44,6 +43,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import SIGTERMException
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_helpers import _ModuleMode, is_overridden
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
 

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 3 additions & 8 deletions

@@ -37,6 +37,7 @@
 from lightning.pytorch.callbacks.timer import Timer
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
@@ -125,14 +126,8 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             )
             return
 
-        progress_bar_callback = self.trainer.progress_bar_callback
-        is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
-
         model_summary: ModelSummary
-        if progress_bar_callback is not None and is_progress_bar_rich:
-            model_summary = RichModelSummary()
-        else:
-            model_summary = ModelSummary()
+        model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
         self.trainer.callbacks.append(model_summary)
 
     def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
@@ -157,7 +152,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
             )
 
         if enable_progress_bar:
-            progress_bar_callback = TQDMProgressBar()
+            progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
             self.trainer.callbacks.append(progress_bar_callback)
 
     def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None:
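
The connector only appends a default when no ProgressBar callback was passed, so opting back into TQDM stays a one-liner; a usage sketch:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

# an explicitly passed progress bar suppresses the rich-based default
trainer = Trainer(callbacks=[TQDMProgressBar()])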

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@
 
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
+_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
 
 
 @functools.lru_cache(maxsize=128)
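
`RequirementCache` defers the check until first use and caches the result; a minimal sketch of how the relocated flag behaves (assuming `lightning_utilities` is installed):

from lightning_utilities.core.imports import RequirementCache

_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
if _RICH_AVAILABLE:  # truthiness runs (and caches) the import-and-version check
    from rich.console import Console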

src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 1 addition & 2 deletions

@@ -17,9 +17,8 @@
 
 from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
 from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
-from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
+from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _RICH_AVAILABLE
 
 _SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
 

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 14 additions & 3 deletions

@@ -18,7 +18,7 @@
 from collections import defaultdict
 from typing import Union
 from unittest import mock
-from unittest.mock import ANY, Mock, PropertyMock, call
+from unittest.mock import ANY, Mock, PropertyMock, call, patch
 
 import pytest
 import torch
@@ -109,6 +109,7 @@ def test_tqdm_progress_bar_misconfiguration():
         Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)
 
 
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 @pytest.mark.parametrize("num_dl", [1, 2])
 def test_tqdm_progress_bar_totals(tmp_path, num_dl):
     """Test that the progress finishes with the correct total steps processed."""
@@ -203,6 +204,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
     assert pbar.predict_progress_bar.leave
 
 
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_tqdm_progress_bar_fast_dev_run(tmp_path):
     model = BoringModel()
 
@@ -323,6 +325,7 @@ def test_tqdm_progress_bar_default_value(tmp_path):
 
 
 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_tqdm_progress_bar_value_on_colab(tmp_path):
     """Test that Trainer will override the default in Google COLAB."""
     trainer = Trainer(default_root_dir=tmp_path)
@@ -411,6 +414,7 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_rate: int):
     assert progress_bar.test_progress_bar.n_values == updates
 
 
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_tensor_to_float_conversion(tmp_path):
     """Check tensor gets converted to float."""
 
@@ -424,7 +428,13 @@ def training_step(self, batch, batch_idx):
     trainer = Trainer(
         default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
     )
-    trainer.fit(TestModel())
+
+    with mock.patch.object(sys.stdout, "write") as mock_write:
+        trainer.fit(TestModel())
+        bar_updates = "".join(call.args[0] for call in mock_write.call_args_list)
+        assert "a=0.123" in bar_updates
+        assert "b=1.000" in bar_updates
+        assert "c=2.000" in bar_updates
 
     torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
     assert trainer.progress_bar_metrics["b"] == 1.0
@@ -616,6 +626,7 @@ def test_progress_bar_max_val_check_interval(
     assert pbar_callback.is_enabled
 
 
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 @RunIf(min_cuda_gpus=2, standalone=True)
 @pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
 def test_progress_bar_max_val_check_interval_ddp(tmp_path, val_check_interval):
@@ -703,7 +714,7 @@ def get_metrics(self, trainer, pl_module):
         del items["v_num"]
         # this is equivalent to mocking `set_postfix` as this method gets called every time
         self.calls[trainer.state.fn].append((
-            trainer.state.stage,
+            trainer.state.stage.value,
             trainer.current_epoch,
             trainer.global_step,
             items,
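
These tests pin the pre-rich default by patching the flag where the connector looks it up, not where it is defined; the same pattern works in any test that assumes the TQDM bar (sketch with a hypothetical test name):

from unittest.mock import patch

@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_defaults_to_tqdm(tmp_path):
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import TQDMProgressBar

    trainer = Trainer(default_root_dir=tmp_path)
    assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)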

tests/tests_pytorch/callbacks/test_callbacks.py

Lines changed: 2 additions & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 from pathlib import Path
 from re import escape
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 from lightning_utilities.test.warning import no_warning_call
@@ -119,6 +119,7 @@ def load_state_dict(self, state_dict) -> None:
         self.state = state_dict["state"]
 
 
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_resume_callback_state_saved_by_type_stateful(tmp_path):
     """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using
     state_dict/load_state_dict."""
