Description
`pytorch-lightning/src/lightning/pytorch/callbacks/progress/tqdm_progress.py`, lines 285 to 286 in `918a1a6`:

```python
if not self.train_progress_bar.disable:
    self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
```
Is the snippet above a safe implementation for DDP?

Based on my understanding, this call leads to the stack trace below and, if certain conditions are met, ultimately to a sync operation on the result metrics. However, `self.train_progress_bar.disable` is only `False` (i.e. the bar is enabled) on the rank-zero process, so the metric computation will only succeed if there happens to be another codepath on the non-rank-zero devices that requests `self._logger_connector.metrics` at the same time.
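For illustration, here is a minimal, self-contained sketch (plain `torch.distributed`, not Lightning code) of the failure mode: a collective op that only rank 0 enters blocks forever, because collectives require every rank in the process group to participate.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    metric = torch.tensor([float(rank)])
    if rank == 0:
        # Analogous to `train_progress_bar.disable` being False only on rank 0:
        # rank 0 waits here for peers that never enter the collective.
        dist.all_reduce(metric)
    # Analogous to a barrier in another callback: the non-rank-zero
    # processes wait here instead.
    dist.barrier()


if __name__ == "__main__":
    # Deadlocks by design: rank 0 is stuck in all_reduce, rank 1 in barrier.
    mp.spawn(worker, args=(2,), nprocs=2)
```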
Stack trace
`pytorch-lightning/src/lightning/pytorch/callbacks/progress/progress_bar.py`, lines 180 to 201 in `918a1a6`:

```python
def get_metrics(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> dict[str, Union[int, str, float, dict[str, float]]]:
    r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.

    Implement this to override the items displayed in the progress bar.

    Here is an example of how to override the defaults:

    .. code-block:: python

        def get_metrics(self, trainer, model):
            # don't show the version number
            items = super().get_metrics(trainer, model)
            items.pop("v_num", None)
            return items

    Return:
        Dictionary with the items to be displayed in the progress bar.

    """
    standard_metrics = get_standard_metrics(trainer)
    pbar_metrics = trainer.progress_bar_metrics
```
`pytorch-lightning/src/lightning/pytorch/trainer/trainer.py`, lines 1675 to 1683 in `918a1a6`:

```python
@property
def progress_bar_metrics(self) -> _PBAR_DICT:
    """The metrics sent to the progress bar.

    This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log` with the
    :paramref:`~lightning.pytorch.core.LightningModule.log.prog_bar` argument set.

    """
    return self._logger_connector.progress_bar_metrics
```
`pytorch-lightning/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py`, lines 253 to 258 in `918a1a6`:

```python
@property
def progress_bar_metrics(self) -> _PBAR_DICT:
    if self.trainer._results:
        metrics = self.metrics["pbar"]
        self._progress_bar_metrics.update(metrics)
    return self._progress_bar_metrics
```
`pytorch-lightning/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py`, lines 232 to 237 in `918a1a6`:

```python
@property
def metrics(self) -> _METRICS:
    """This function returns either batch or epoch metrics."""
    on_step = self._first_loop_iter is not None
    assert self.trainer._results is not None
    return self.trainer._results.metrics(on_step)
```
`pytorch-lightning/src/lightning/pytorch/trainer/connectors/logger_connector/result.py`, lines 471 to 476 in `918a1a6`:

```python
def metrics(self, on_step: bool) -> _METRICS:
    metrics = _METRICS(callback={}, log={}, pbar={})
    for _, result_metric in self.valid_items():
        # extract forward_cache or computed from the _ResultMetric
        value = self._get_cache(result_metric, on_step)
```
`pytorch-lightning/src/lightning/pytorch/trainer/connectors/logger_connector/result.py`, lines 425 to 440 in `918a1a6`:

```python
@staticmethod
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
    cache = None
    if on_step and result_metric.meta.on_step:
        cache = result_metric._forward_cache
    elif not on_step and result_metric.meta.on_epoch:
        if result_metric._computed is None:
            should = result_metric.meta.sync.should
            if not should and result_metric.is_tensor and _distributed_is_initialized():
                warning_cache.warn(
                    f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                    " when logging on epoch level in distributed setting to accumulate the metric across"
                    " devices.",
                    category=PossibleUserWarning,
                )
            result_metric.compute()
```
I ran into this while debugging an issue in my training script where DDP was hanging at the end of the first epoch. Changing the original code as follows resolved the deadlock for me:

```python
metrics = self.get_metrics(trainer, pl_module)
if not self.train_progress_bar.disable:
    self.train_progress_bar.set_postfix(metrics)
```
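For anyone who wants to apply this as a workaround without patching the package, a rough sketch is to subclass `TQDMProgressBar` and hoist `get_metrics` out of the rank-zero-only branch. The class name below is made up, and this replaces the hook wholesale, so it assumes the hook body in your installed version matches the snippet quoted at the top.

```python
from lightning.pytorch.callbacks import TQDMProgressBar


class DDPSafeTQDMProgressBar(TQDMProgressBar):
    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Compute the metrics on every rank so that any metric sync (a collective op)
        # is entered by all processes, then update the bar on rank zero only.
        metrics = self.get_metrics(trainer, pl_module)
        if not self.train_progress_bar.disable:
            self.train_progress_bar.set_postfix(metrics)
```

Passing an instance via `Trainer(callbacks=[DDPSafeTQDMProgressBar()])` should make it replace the default progress bar.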
To provide a little more detail, I have a `CustomCallback` that calls `trainer.strategy.barrier()` during `on_train_epoch_end`. At the end of the first training epoch, the rank-zero process hangs on `TQDMProgressBar.on_train_epoch_end` and the non-rank-zero processes hang on `CustomCallback.on_train_epoch_end`.
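In case it helps triage, here is a hypothetical minimal repro along the lines of my setup. The model, data, and callback below are illustrative stand-ins rather than my actual script; the key ingredients are an epoch-level `prog_bar=True, sync_dist=True` metric plus a barrier in a callback's `on_train_epoch_end`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback


class CustomCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        trainer.strategy.barrier()  # non-rank-zero processes end up waiting here


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).mean()
        # Epoch-level synced metric shown on the progress bar: compute() at epoch
        # end becomes a collective op that only rank zero reaches via set_postfix.
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=8)
    trainer = pl.Trainer(
        accelerator="cpu",
        devices=2,
        strategy="ddp",
        max_epochs=2,
        callbacks=[CustomCallback()],
    )
    trainer.fit(BoringModel(), data)  # expected to hang at the end of the first epoch
```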