Commit fb2e8d3

SkafteNicki, pre-commit-ci[bot], bhimrazy, and deependujha authored
Sync dist clarification and consistency (#21012)
* change default reduction in xla to mean
* clarify docs that sync_dist applies average
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Deependu <[email protected]>
1 parent 95ad9c2 commit fb2e8d3

File tree: 3 files changed (+6 −3 lines)


docs/source-pytorch/accelerators/accelerator_prepare.rst

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ Synchronize validation and test logging
 ***************************************
 
 When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
-This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step.
+This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
 This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
 The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
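To ground the clarified wording, here is a minimal, illustrative sketch (not part of this commit) of a LightningModule that logs with ``sync_dist=True`` in its validation and test steps, so the logged values are averaged across all processes; the model, loss, and shapes are assumptions:

import torch
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class LitClassifier(LightningModule):
    # Illustrative module; only the logging calls matter here.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 4)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # sync_dist=True averages the value across all processes, so every
        # rank tracks the same metric when selecting checkpoints.
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("test_loss", loss, sync_dist=True)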

docs/source-pytorch/extensions/logging.rst

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ The :meth:`~lightning.pytorch.core.LightningModule.log` method has a few options
 * ``logger``: Logs to the logger like ``Tensorboard``, or any other custom logger passed to the :class:`~lightning.pytorch.trainer.trainer.Trainer` (Default: ``True``).
 * ``reduce_fx``: Reduction function over step values for end of epoch. Uses :func:`torch.mean` by default and is not applied when a :class:`torchmetrics.Metric` is logged.
 * ``enable_graph``: If True, will not auto detach the graph.
-* ``sync_dist``: If True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.
+* ``sync_dist``: If True, averages the metric across devices. Use with care as this may lead to a significant communication overhead.
 * ``sync_dist_group``: The DDP group to sync across.
 * ``add_dataloader_idx``: If True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.
 * ``batch_size``: Current batch size used for accumulating logs logged with ``on_epoch=True``. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
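As a rough, illustrative sketch of how several of these options combine in a single ``self.log`` call (not part of this diff; the metric, shapes, and ``forward`` are assumptions):

import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # assumes a forward() defined elsewhere
        acc = (logits.argmax(dim=-1) == y).float().mean()
        self.log(
            "val_acc",
            acc,
            prog_bar=True,         # show in the progress bar
            logger=True,           # send to the attached logger (the default)
            reduce_fx="mean",      # epoch-level reduction over step values (the default)
            sync_dist=True,        # average across devices; adds communication overhead
            batch_size=x.size(0),  # explicit batch size for epoch accumulation
        )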

src/lightning/pytorch/strategies/xla.py

Lines changed: 4 additions & 1 deletion
@@ -247,7 +247,10 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
 
     @override
     def reduce(
-        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+        self,
+        output: Union[Tensor, Any],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Tensor:
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
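To make the effect of the new ``reduce_op`` default concrete, here is a toy, single-process stand-in for the reduction (an illustration only, not the library's implementation; the real ``XLAStrategy.reduce`` all-reduces across TPU processes):

from typing import Any, Optional, Union

import torch
from torch import Tensor


def toy_reduce(output: Union[Tensor, Any], reduce_op: Optional[str] = "mean") -> Tensor:
    """Toy stand-in: reduce a stacked tensor of per-process values on one process."""
    if not isinstance(output, Tensor):
        output = torch.tensor(output)
    if reduce_op in (None, "mean", "avg"):
        # Matches the new default: unqualified calls average the values.
        return output.float().mean()
    if reduce_op == "sum":
        return output.sum()
    raise ValueError(f"Unsupported reduce_op: {reduce_op}")


# Per-process losses gathered into one tensor for illustration.
per_process_losses = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(toy_reduce(per_process_losses))         # mean by default -> 2.5
print(toy_reduce(per_process_losses, "sum"))  # explicit sum -> 10.0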
