Skip to content

Commit 25b1343

Browse files
sudiptob2, Borda, bhimrazy, deependujha, Copilot
authored
Fix double iteration bug when resumed from a checkpoint. (#20775)
* Fix double iteration bug when resumed from a checkpoint.
* Apply suggestions from code review
* update wording in the comments.
  Signed-off-by: sudipto baral <[email protected]>
* update test
  Signed-off-by: sudipto baral <[email protected]>
* Add independent flag to track checkpoint resumption.
  Signed-off-by: sudipto baral <[email protected]>
* lint
  Signed-off-by: sudipto baral <[email protected]>
* update
* Update src/lightning/pytorch/loops/training_epoch_loop.py
  Co-authored-by: Copilot <[email protected]>
* Update .github/workflows/ci-tests-pytorch.yml
* update
* skip

---------

Signed-off-by: sudipto baral <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Deependu <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent fb2e8d3 commit 25b1343

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

src/lightning/pytorch/loops/loop.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class _Loop:
2323
def __init__(self, trainer: "pl.Trainer") -> None:
    """Create the loop with all restart/resume flags cleared.

    Args:
        trainer: The trainer this loop is attached to; stored as ``self.trainer``.
    """
    self._restarting = False
    self._loaded_from_state_dict = False
    # Set to True by `load_state_dict` and cleared again in `on_iteration_done`;
    # exposed read-only through the `is_resuming` property.
    self._resuming_from_checkpoint = False
    self.trainer = trainer
2728

2829
@property
@@ -38,6 +39,11 @@ def restarting(self, restarting: bool) -> None:
3839
if isinstance(loop, _Loop):
3940
loop.restarting = restarting
4041

42+
@property
def is_resuming(self) -> bool:
    """Indicates whether training is being resumed from a checkpoint.

    Returns the ``_resuming_from_checkpoint`` flag, which ``load_state_dict`` sets to
    ``True`` and ``on_iteration_done`` resets to ``False``.
    """
    return self._resuming_from_checkpoint
46+
4147
def reset_restart_stage(self) -> None:
    """Do nothing; invoked at the end of ``on_iteration_done``.

    NOTE(review): presumably subclasses override this to clear their own restart-stage
    state -- confirm against the concrete loops.
    """
    pass
4349

@@ -87,6 +93,7 @@ def load_state_dict(
8793
v.load_state_dict(state_dict.copy(), prefix + k + ".")
8894
self.restarting = True
8995
self._loaded_from_state_dict = True
96+
self._resuming_from_checkpoint = True
9097

9198
def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
9299
for k, v in self.__dict__.items():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
102109
def on_iteration_done(self) -> None:
    """Clear the restart- and resume-related flags, then reset the restart stage."""
    self._restarting = False
    self._loaded_from_state_dict = False
    # Drop the checkpoint-resumption marker so `is_resuming` reports False from here on.
    self._resuming_from_checkpoint = False
    self.reset_restart_stage()

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,11 @@ def reset(self) -> None:
237237

238238
def on_run_start(self, data_fetcher: _DataFetcher) -> None:
239239
# `iter()` was called once in `FitLoop.setup_data()` already
240-
if self.trainer.current_epoch > 0 and not self.restarting:
240+
# Call `iter()` again only when:
241+
# 1. Not restarting
242+
# 2. Not resuming from checkpoint (not is_resuming)
243+
# 3. Past first epoch (current_epoch > 0)
244+
if self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming and not self.restarting:
241245
iter(data_fetcher) # creates the iterator inside the fetcher
242246

243247
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# This test checks that training can be resumed from a checkpoint file when using an IterableDataset.
16+
# It contains the reproduction code from issue #19427.
17+
# Ref: https://github.com/Lightning-AI/pytorch-lightning/issues/19427
18+
import multiprocessing as mp
19+
import os
20+
import sys
21+
from collections.abc import Iterator
22+
from pathlib import Path
23+
from queue import Queue
24+
25+
import numpy as np
26+
import pytest
27+
from torch.utils.data import DataLoader, IterableDataset
28+
29+
from lightning.pytorch import Trainer
30+
from lightning.pytorch.demos.boring_classes import BoringModel
31+
32+
33+
class QueueDataset(IterableDataset):
    """Iterable dataset that pulls its samples from a queue.

    Every call to ``__iter__`` consumes exactly five ``(tensor, idx)`` items from the
    queue and yields only the tensor part, so each extra iteration over the dataset
    drains five more items.
    """

    def __init__(self, queue: Queue) -> None:
        """Store the queue that ``__iter__`` reads from.

        Args:
            queue: Source of ``(tensor, idx)`` pairs.
        """
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        """Yield the tensor of five consecutive queue items.

        Raises:
            queue.Empty: If no item becomes available within the 5 second timeout.
        """
        for _ in range(5):
            # The second tuple element (the index) is intentionally discarded.
            tensor, _ = self.queue.get(timeout=5)
            yield tensor
42+
43+
44+
def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
    """Fit a ``BoringModel`` on queue-backed data, resuming from ``ckpt_path`` if it exists.

    When ``ckpt_path`` does not exist yet, training starts from scratch and a checkpoint
    is written to that path afterwards. When it already exists, training resumes from it
    and no new checkpoint is saved.

    Args:
        queue: Queue feeding the ``QueueDataset``.
        max_epochs: Value passed to ``Trainer(max_epochs=...)``.
        ckpt_path: Checkpoint location to resume from or to create.
    """
    dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None)
    trainer = Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=False,
        enable_checkpointing=False,
        devices=1,
        logger=False,
    )
    # Decide before fitting: `None` means "start fresh", matching the no-argument call.
    resume_from = str(ckpt_path) if ckpt_path.exists() else None
    trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from)
    if resume_from is None:
        # First run only: persist the state so that a later call can resume from it.
        trainer.save_checkpoint(str(ckpt_path))
58+
59+
60+
@pytest.mark.skipif(sys.platform == "darwin", reason="Skip on macOS due to multiprocessing issues")
def test_resume_training_with(tmp_path):
    """Test resuming training from a checkpoint file using an IterableDataset."""
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    # 20 items = 4 epochs in total (2 initial + 2 resumed) x 5 items per `QueueDataset` pass.
    for idx in range(20):
        q.put((arr, idx))

    max_epoch = 2
    ckpt_path = tmp_path / "model.ckpt"
    # First run: no checkpoint exists yet, so this trains from scratch and saves one.
    train_model(q, max_epoch, ckpt_path)

    assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
    ckpt_size = os.path.getsize(ckpt_path)
    assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"

    # Second run resumes from the checkpoint. There is no explicit assertion here: the
    # test passes if this call completes. Presumably an extra pass over the dataloader
    # on resume would exhaust the queue and make `queue.get(timeout=5)` raise.
    train_model(q, max_epoch + 2, ckpt_path)

0 commit comments

Comments
 (0)