
Commit ce2cdda

Merge branch 'master' into feat/to_tensorrt
2 parents 817e145 + a0ce930 commit ce2cdda

File tree

16 files changed: +306 additions, −48 deletions


src/lightning/fabric/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -16,11 +16,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
+
 ### Changed
 
 - Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
 
 
+### Fixed
+
+- Fix XLA strategy to add support for `global_ordinal`, `local_ordinal`, `world_size` which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852))
+
+
 - fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))
 
 
```

src/lightning/fabric/plugins/environments/xla.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -66,6 +66,11 @@ def world_size(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.world_size()
+
         import torch_xla.core.xla_model as xm
 
         return xm.xrt_world_size()
@@ -82,6 +87,11 @@ def global_rank(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.global_ordinal()
+
         import torch_xla.core.xla_model as xm
 
         return xm.get_ordinal()
@@ -98,6 +108,11 @@ def local_rank(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.local_ordinal()
+
         import torch_xla.core.xla_model as xm
 
         return xm.get_local_ordinal()
```
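All three hunks apply the same pattern: prefer the `torch_xla.runtime` API (`world_size`, `global_ordinal`, `local_ordinal`) on torch_xla >= 2.1, and fall back to the deprecated `xla_model` helpers otherwise. A minimal standalone sketch of the pattern, assuming `_XLA_GREATER_EQUAL_2_1` is a `RequirementCache`-style flag like the ones used elsewhere in this commit (the repo's actual definition may differ):

```python
# Sketch only: version-gated dispatch between the new torch_xla runtime API
# and the deprecated xla_model helpers. The flag definition is an assumption
# modeled on the RequirementCache usage visible in imports.py below.
from lightning_utilities.core.imports import RequirementCache

_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")


def world_size() -> int:
    """Number of XLA processes, using whichever API this torch_xla provides."""
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xr  # new API, torch_xla >= 2.1

        return xr.world_size()
    import torch_xla.core.xla_model as xm  # deprecated fallback, < 2.1

    return xm.xrt_world_size()
```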

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Allow returning `ONNXProgram` when calling `to_onnx(dynamo=True)` ([#20811](https://github.com/Lightning-AI/pytorch-lightning/pull/20811))
 
 
+- Default to RichProgressBar and RichModelSummary if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
+
+
 ### Removed
 
 -
```

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 28 additions & 20 deletions
```diff
@@ -17,15 +17,15 @@
 from datetime import timedelta
 from typing import Any, Optional, Union, cast
 
-from lightning_utilities.core.imports import RequirementCache
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
-_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
-
 if _RICH_AVAILABLE:
     from rich import get_console, reconfigure
     from rich.console import Console, RenderableType
@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
             return Text()
         if self._trainer.training and task.id not in self._tasks:
             self._tasks[task.id] = "None"
-            if self._renderable_cache:
+            if self._renderable_cache and self._current_task_id in self._renderable_cache:
                 self._current_task_id = cast(TaskID, self._current_task_id)
                 self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
             self._current_task_id = task.id
@@ -184,8 +184,11 @@ def render(self, task: "Task") -> Text:
 
     def _generate_metrics_texts(self) -> Generator[str, None, None]:
         for name, value in self._metrics.items():
-            if not isinstance(value, (str, int)):
-                value = f"{value:{self._metrics_format}}"
+            if not isinstance(value, str):
+                try:
+                    value = f"{value:{self._metrics_format}}"
+                except (TypeError, ValueError):
+                    value = str(value)
             yield f"{name}: {value}"
 
 
@@ -465,17 +468,12 @@ def _initialize_train_progress_bar_id(self) -> None:
         self.train_progress_bar_id = self._add_task(total_batches, train_description)
 
     def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
-        if self.progress is not None and self.is_enabled:
-            assert progress_bar_id is not None
+        if self.progress is not None and self.is_enabled and progress_bar_id is not None:
             total = self.progress.tasks[progress_bar_id].total
             assert total is not None
             if not self._should_update(current, total):
                 return
-
-            leftover = current % self.refresh_rate
-            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
-            self.progress.update(progress_bar_id, advance=advance, visible=visible)
-            self.refresh()
+            self.progress.update(progress_bar_id, completed=current, visible=visible)
 
     def _should_update(self, current: int, total: Union[int, float]) -> bool:
         return current % self.refresh_rate == 0 or current == total
@@ -572,9 +570,13 @@ def on_validation_batch_end(
         if self.is_disabled:
             return
         if trainer.sanity_checking:
-            self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
-        elif self.val_progress_bar_id is not None:
-            self._update(self.val_progress_bar_id, batch_idx + 1)
+            if self.val_sanity_progress_bar_id is not None:
+                self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
+            return
+
+        if self.val_progress_bar_id is None:
+            return
+        self._update(self.val_progress_bar_id, batch_idx + 1)
         self.refresh()
 
     @override
@@ -587,9 +589,8 @@ def on_test_batch_end(
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if self.is_disabled:
+        if self.is_disabled or self.test_progress_bar_id is None:
             return
-        assert self.test_progress_bar_id is not None
         self._update(self.test_progress_bar_id, batch_idx + 1)
         self.refresh()
 
@@ -603,9 +604,8 @@ def on_predict_batch_end(
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if self.is_disabled:
+        if self.is_disabled or self.predict_progress_bar_id is None:
             return
-        assert self.predict_progress_bar_id is not None
         self._update(self.predict_progress_bar_id, batch_idx + 1)
         self.refresh()
 
@@ -632,6 +632,14 @@ def _reset_progress_bar_ids(self) -> None:
         self.test_progress_bar_id = None
         self.predict_progress_bar_id = None
 
+    @override
+    def get_metrics(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
+        items = super().get_metrics(trainer, pl_module)
+        # convert all metrics to float before sending to rich
+        return apply_to_collection(items, torch.Tensor, lambda x: x.item())
+
     def _update_metrics(
         self,
         trainer: "pl.Trainer",
```
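Two behavioral changes stand out: `_update` now sets `completed=current` instead of computing an `advance` delta, which makes redraws idempotent, and the new `get_metrics` override converts tensor metrics to plain floats before rich formats them. A small sketch of what that conversion does (the metric names here are invented for illustration):

```python
# Sketch of the tensor-to-float conversion performed by the new
# RichProgressBar.get_metrics override; the metric dict is made up.
import torch
from lightning_utilities.core.apply_func import apply_to_collection

items = {"loss": torch.tensor(0.1234), "epoch": 3, "v_num": "1"}
floats = apply_to_collection(items, torch.Tensor, lambda x: x.item())
# rich can now safely apply a float format spec such as f"{value:.3f}"
print(floats)  # e.g. {'loss': 0.1234..., 'epoch': 3, 'v_num': '1'}
```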

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -274,7 +274,7 @@ def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.train_progress_bar.total):
+        if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total):
             _update_n(self.train_progress_bar, n)
             self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
@@ -322,7 +322,7 @@ def on_validation_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.val_progress_bar.total):
+        if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total):
             _update_n(self.val_progress_bar, n)
 
     @override
@@ -363,7 +363,7 @@ def on_test_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.test_progress_bar.total):
+        if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total):
             _update_n(self.test_progress_bar, n)
 
     @override
@@ -402,7 +402,7 @@ def on_predict_batch_end(
         dataloader_idx: int = 0,
     ) -> None:
         n = batch_idx + 1
-        if self._should_update(n, self.predict_progress_bar.total):
+        if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total):
             _update_n(self.predict_progress_bar, n)
 
     @override
```
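All four hunks apply the same fix: guard against a progress bar that was never initialized (it stays `None` until its stage starts) before consulting `_should_update`. For reference, the throttle these callbacks share redraws every `refresh_rate` batches and always on the final batch; a self-contained sketch:

```python
# Sketch of the shared refresh throttle (mirrors the one-line _should_update
# shown in the rich_progress.py diff above; the names here are local).
def should_update(current: int, total: int, refresh_rate: int) -> bool:
    return current % refresh_rate == 0 or current == total


# With refresh_rate=4 and 10 batches, the bar redraws at batches 4, 8 and 10.
assert [n for n in range(1, 11) if should_update(n, 10, 4)] == [4, 8, 10]
```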

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@
 from typing_extensions import override
 
 from lightning.pytorch.callbacks import ModelSummary
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_summary import get_human_readable_count
 
 
```

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -25,7 +25,6 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _set_sampler_epoch
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.progress import _BatchProgress
@@ -44,6 +43,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import SIGTERMException
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_helpers import _ModuleMode, is_overridden
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
 
```

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -37,6 +37,7 @@
 from lightning.pytorch.callbacks.timer import Timer
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
@@ -125,14 +126,8 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             )
             return
 
-        progress_bar_callback = self.trainer.progress_bar_callback
-        is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
-
         model_summary: ModelSummary
-        if progress_bar_callback is not None and is_progress_bar_rich:
-            model_summary = RichModelSummary()
-        else:
-            model_summary = ModelSummary()
+        model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
         self.trainer.callbacks.append(model_summary)
 
     def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
@@ -157,7 +152,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
             )
 
         if enable_progress_bar:
-            progress_bar_callback = TQDMProgressBar()
+            progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
             self.trainer.callbacks.append(progress_bar_callback)
 
     def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None:
```
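Together with the changelog entry above (#9580), this changes the out-of-the-box defaults: a bare `Trainer()` now gets `RichProgressBar` and `RichModelSummary` whenever `rich` is importable, and the model summary no longer depends on which progress bar is active. A hedged usage sketch (standard Lightning behavior, but verify against your installed version):

```python
# Sketch: default callback selection after this commit. Passing an explicit
# progress bar callback still overrides the default, as before.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

trainer = Trainer()  # Rich callbacks if `rich>=10.2.2` is installed
trainer_tqdm = Trainer(callbacks=[TQDMProgressBar()])  # opt back into tqdm
```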

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@
 
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
+_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
 
 
 @functools.lru_cache(maxsize=128)
```
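Moving `_RICH_AVAILABLE` here removes the awkward dependency of utility modules on the rich progress-bar module. `RequirementCache` checks the pinned requirement once and caches the answer; a short sketch of how the flag is consumed:

```python
# Sketch: a RequirementCache instance is truthy iff the requirement is
# satisfied in the current environment; the check runs once and is cached.
from lightning_utilities.core.imports import RequirementCache

_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")

if _RICH_AVAILABLE:
    from rich.console import Console

    Console().print("[green]rich available[/green]")
else:
    print("rich not installed; using plain output")
```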

src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -17,9 +17,8 @@
 
 from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
 from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
-from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE, _TORCH_TRT_AVAILABLE
-from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
+from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _RICH_AVAILABLE
 
 _SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
 
```
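With the flag now living in `utilities.imports`, test helpers import it from a single place. A hedged sketch of how a suite might gate rich-specific tests on it (plain `pytest.mark.skipif` here; Lightning's own `RunIf` helper wraps the same idea, and the asserted default is an assumption):

```python
# Sketch: skipping a rich-only test when the dependency is absent.
import pytest

from lightning.pytorch.utilities.imports import _RICH_AVAILABLE


@pytest.mark.skipif(not _RICH_AVAILABLE, reason="requires rich>=10.2.2")
def test_rich_progress_bar_default_refresh_rate():
    from lightning.pytorch.callbacks import RichProgressBar

    assert RichProgressBar().refresh_rate == 1  # assumed default
```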
