Commit b0d0c9c

Merge branch 'master' into feature/9580-rich-defaults

2 parents 905f065 + 25b1343
File tree: 6 files changed, 95 insertions(+), 4 deletions(-)


docs/source-pytorch/accelerators/accelerator_prepare.rst

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ Synchronize validation and test logging
 ***************************************
 
 When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
-This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step.
+This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
 This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
 The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
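For context, a minimal sketch of the pattern this doc change describes; the class name, layer size, and loss are illustrative only, while ``self.log(..., sync_dist=True)`` is the documented API:

import torch
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class LitModel(LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # sync_dist=True averages this value across all processes before it is
        # logged, so every rank tracks the same number when checkpointing.
        self.log("val_loss", loss, on_epoch=True, sync_dist=True)
        return loss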

docs/source-pytorch/extensions/logging.rst

Lines changed: 1 addition & 1 deletion

@@ -137,7 +137,7 @@ The :meth:`~lightning.pytorch.core.LightningModule.log` method has a few options
 * ``logger``: Logs to the logger like ``Tensorboard``, or any other custom logger passed to the :class:`~lightning.pytorch.trainer.trainer.Trainer` (Default: ``True``).
 * ``reduce_fx``: Reduction function over step values for end of epoch. Uses :func:`torch.mean` by default and is not applied when a :class:`torchmetrics.Metric` is logged.
 * ``enable_graph``: If True, will not auto detach the graph.
-* ``sync_dist``: If True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.
+* ``sync_dist``: If True, averages the metric across devices. Use with care as this may lead to a significant communication overhead.
 * ``sync_dist_group``: The DDP group to sync across.
 * ``add_dataloader_idx``: If True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.
 * ``batch_size``: Current batch size used for accumulating logs logged with ``on_epoch=True``. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
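A short sketch of how several of these options combine in a ``training_step``; the module and metric names (``LitClassifier``, ``train_acc``) are illustrative, while the ``self.log`` parameters are the ones documented above:

import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import LightningModule


class LitClassifier(LightningModule):  # illustrative module, not part of the commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.train_acc = torchmetrics.classification.BinaryAccuracy()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Step values are reduced with torch.mean at epoch end (the reduce_fx default);
        # batch_size is passed explicitly for cases where it cannot be inferred.
        self.log("train_loss", loss, on_step=True, on_epoch=True, batch_size=x.size(0))
        # A torchmetrics.Metric handles its own synchronization, so reduce_fx is not applied.
        self.train_acc(logits.argmax(dim=-1), y)
        self.log("train_acc", self.train_acc, on_epoch=True)
        return loss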

src/lightning/pytorch/loops/loop.py

Lines changed: 8 additions & 0 deletions

@@ -23,6 +23,7 @@ class _Loop:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
         self.trainer = trainer
 
     @property
@@ -38,6 +39,11 @@ def restarting(self, restarting: bool) -> None:
             if isinstance(loop, _Loop):
                 loop.restarting = restarting
 
+    @property
+    def is_resuming(self) -> bool:
+        """Indicates whether training is being resumed from a checkpoint."""
+        return self._resuming_from_checkpoint
+
     def reset_restart_stage(self) -> None:
         pass
 
@@ -87,6 +93,7 @@ def load_state_dict(
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
         self.restarting = True
         self._loaded_from_state_dict = True
+        self._resuming_from_checkpoint = True
 
     def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
         for k, v in self.__dict__.items():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
     def on_iteration_done(self) -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
        self.reset_restart_stage()
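A minimal sketch of how the new ``is_resuming`` property can be consulted from user code, assuming a hypothetical callback; ``trainer.fit_loop.is_resuming`` is the same accessor used by the training epoch loop change below:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback


class SkipWarmupOnResume(Callback):  # hypothetical callback, for illustration only
    def on_train_start(self, trainer: Trainer, pl_module) -> None:
        # is_resuming is set when the loops are restored from a checkpoint
        # (in load_state_dict) and cleared in on_iteration_done.
        if trainer.fit_loop.is_resuming:
            print("Resuming from a checkpoint; skipping one-time warmup work.")
        else:
            print("Fresh run; performing one-time warmup work.")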

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 1 deletion

@@ -237,7 +237,11 @@ def reset(self) -> None:
 
     def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
-        if self.trainer.current_epoch > 0 and not self.restarting:
+        # Call `iter()` again only when:
+        # 1. Not restarting
+        # 2. Not resuming from checkpoint (not is_resuming)
+        # 3. Past first epoch (current_epoch > 0)
+        if self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher
 
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
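A toy, self-contained illustration of the failure mode the ``is_resuming`` guard above avoids: an object that prefetches as soon as its iterator is created will silently consume items from a single-pass source such as a queue. This only mimics the behaviour; it is not Lightning's actual ``_DataFetcher``:

from queue import Queue


class PrefetchingFetcher:
    """Mimics a data fetcher that prefetches one item the moment it is created."""

    def __init__(self, queue: Queue) -> None:
        self.queue = queue
        self.prefetched = queue.get(timeout=1)  # consumed immediately

    def __iter__(self):
        yield self.prefetched
        while not self.queue.empty():
            yield self.queue.get(timeout=1)


q: Queue = Queue()
for i in range(4):
    q.put(i)

_ = PrefetchingFetcher(q)              # an extra, discarded fetcher still consumes item 0
resumed = list(PrefetchingFetcher(q))  # the "real" resumed iteration
print(resumed)                         # [1, 2, 3] -- item 0 was silently lost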

src/lightning/pytorch/strategies/xla.py

Lines changed: 4 additions & 1 deletion

@@ -247,7 +247,10 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
 
     @override
     def reduce(
-        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+        self,
+        output: Union[Tensor, Any],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Tensor:
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
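To make the new default concrete, a toy single-process comparison of "mean" versus "sum" reduction over per-device values (no XLA devices involved; the numbers are made up):

import torch

# Pretend these are the per-device values of one logged metric.
per_device = torch.tensor([0.40, 0.60, 0.50, 0.30])

print(per_device.sum())   # tensor(1.8000) -- what a "sum" reduction produces
print(per_device.mean())  # tensor(0.4500) -- "mean", the new default for reduce_op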
Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This test verifies resuming training from a checkpoint file when using an IterableDataset.
# It contains code mentioned in issue #19427.
# Ref: https://github.com/Lightning-AI/pytorch-lightning/issues/19427
import multiprocessing as mp
import os
import sys
from collections.abc import Iterator
from pathlib import Path
from queue import Queue

import numpy as np
import pytest
from torch.utils.data import DataLoader, IterableDataset

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for _ in range(5):
            tensor, _ = self.queue.get(timeout=5)
            yield tensor


def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
    dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None)
    trainer = Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=False,
        enable_checkpointing=False,
        devices=1,
        logger=False,
    )
    if ckpt_path.exists():
        trainer.fit(BoringModel(), dataloader, ckpt_path=str(ckpt_path))
    else:
        trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint(str(ckpt_path))


@pytest.mark.skipif(sys.platform == "darwin", reason="Skip on macOS due to multiprocessing issues")
def test_resume_training_with(tmp_path):
    """Test resuming training from a checkpoint file using an IterableDataset."""
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for idx in range(20):
        q.put((arr, idx))

    max_epoch = 2
    ckpt_path = tmp_path / "model.ckpt"
    train_model(q, max_epoch, ckpt_path)

    assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
    ckpt_size = os.path.getsize(ckpt_path)
    assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"

    train_model(q, max_epoch + 2, ckpt_path)
