Commit 1ecfa6f

SeanNaren, SkafteNicki, and awaelchli authored and committed
[docs] Add step to ensure sync_dist is adding to logging when multi-gpu enabled (#4817)
* Add additional check to ensure validation/test step are updated accordingly

* Update docs/source/multi_gpu.rst
  Co-authored-by: Nicki Skafte <[email protected]>

* Update docs/source/multi_gpu.rst
  Co-authored-by: Nicki Skafte <[email protected]>

* Update docs/source/multi_gpu.rst
  Co-authored-by: Nicki Skafte <[email protected]>

* Update docs/source/multi_gpu.rst
  Co-authored-by: Adrian Wälchli <[email protected]>

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
(cherry picked from commit 9186abe)
1 parent c8e83a1 commit 1ecfa6f

File tree

1 file changed: +27 −0 lines changed

docs/source/multi_gpu.rst

Lines changed: 27 additions & 0 deletions
@@ -103,6 +103,33 @@ Lightning adds the correct samplers when needed, so no need to explicitly add samplers
 
 .. note:: For iterable datasets, we don't do this automatically.
 
+
+Synchronize validation and test logging
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
+This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step.
+This ensures that each GPU worker behaves the same way when tracking model checkpoints, which is important for downstream tasks such as testing the best checkpoint across all workers.
+
+Note that if you use any built-in metrics or custom metrics that use the :ref:`Metrics API <metrics>`, these do not need to be updated and are automatically handled for you.
+
+.. testcode::
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.loss(logits, y)
+        # Add sync_dist=True to sync logging across all GPU workers
+        self.log('validation_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.loss(logits, y)
+        # Add sync_dist=True to sync logging across all GPU workers
+        self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
+
+
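With ``sync_dist=True``, each logged value is reduced across workers before it is recorded, and by default the reduction is a mean. As a rough, hypothetical illustration of the effect (plain Python, not Lightning's actual implementation — the loss values below are made up):

```python
# Hypothetical per-worker validation losses; in a real distributed run,
# each GPU process computes its own value for the same logging call.
worker_losses = [0.42, 0.38, 0.50, 0.46]

# Without syncing, each worker logs its own (different) loss, so checkpoint
# tracking can disagree between processes. With sync_dist=True, the values
# are reduced across workers (mean by default), so every worker logs the
# same number.
synced_loss = sum(worker_losses) / len(worker_losses)
print(synced_loss)  # ≈ 0.44 on every worker
```

Because every process then sees the identical metric, callbacks such as model checkpointing make the same decision on all workers.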
 
 Make models pickleable
 ^^^^^^^^^^^^^^^^^^^^^^
 It's very likely your code is already `pickleable <https://docs.python.org/3/library/pickle.html>`_,
