
Commit 679392b

awaelchli authored and lantiga committed
Avoid warning when resuming mid-epoch checkpoint and using stateful dataloader (#19475)
1 parent a00a999 commit 679392b

File tree

2 files changed (+73, -5 lines)


src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 11 additions & 3 deletions
@@ -18,6 +18,8 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.types import _Stateful
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
@@ -152,10 +154,16 @@ def reset(self) -> None:
             trainer = self.trainer
             if trainer.num_training_batches != float("inf"):
                 expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
-                if self.global_step % expected_steps != 0:
+                loader = trainer.fit_loop._combined_loader
+                assert loader is not None
+                is_resumable_loader = all(isinstance(loader, _Stateful) for loader in loader.flattened)
+                if self.global_step % expected_steps != 0 and not is_resumable_loader:
                     rank_zero_warn(
-                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
-                        " results if further training is done. Consider using an end-of-epoch checkpoint"
+                        "You're resuming from a checkpoint that ended before the epoch ended and your dataloader is"
+                        " not resumable. This can cause unreliable results if further training is done."
+                        " Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing"
+                        " the `state_dict` / `load_state_dict` interface.",
+                        category=PossibleUserWarning,
                     )
         else:
            self.batch_progress.reset_on_run()
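
The `_Stateful` check above passes when every flattened dataloader exposes `state_dict` and `load_state_dict`, which is what the updated warning text asks users to implement. A minimal sketch of such a resumable iterable (the class name, length, and tensor shapes below are illustrative and not part of this commit):

import torch


class ResumableRandomIterable:
    """Illustrative iterable that records its position so training can resume mid-epoch."""

    def __init__(self, length=8):
        self.length = length
        self.index = 0

    def __len__(self):
        return self.length

    def __iter__(self):
        # Start from the restored index instead of 0, and keep it up to date while iterating.
        for i in range(self.index, self.length):
            self.index = i + 1
            yield torch.randn(2, 32)

    def state_dict(self):
        # Persisted into the checkpoint alongside the loop state.
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"]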

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 62 additions & 2 deletions
@@ -15,11 +15,15 @@
 from unittest.mock import Mock, patch
 
 import pytest
+import torch
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.trainer import Trainer
+from lightning_utilities.test.warning import no_warning_call
 
 
-def test_no_val_on_train_epoch_loop_restart(tmpdir):
+def test_no_val_on_train_epoch_loop_restart(tmp_path):
     """Test that training validation loop doesn't get triggered at the beginning of a restart."""
     trainer_kwargs = {
         "max_epochs": 1,
@@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
     trainer = Trainer(**trainer_kwargs)
     model = BoringModel()
     trainer.fit(model)
-    ckpt_path = str(tmpdir / "last.ckpt")
+    ckpt_path = str(tmp_path / "last.ckpt")
     trainer.save_checkpoint(ckpt_path)
 
     trainer_kwargs["max_epochs"] = 2
@@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs):
     model = MyModel()
     trainer.fit(model)
     assert model.last_batch_idx == 3
+
+
+def test_resume_mid_epoch_warning(tmp_path):
+    """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful."""
+
+    class NotStatefulIterable:
+        def __init__(self):
+            self.index = 0
+
+        def __iter__(self):
+            for i in range(self.index, len(self)):
+                yield torch.ones(2, 32) * i
+
+        def __len__(self):
+            return 3
+
+    class StatefulIterable(NotStatefulIterable):
+        def state_dict(self):
+            return {"index": self.index}
+
+        def load_state_dict(self, state_dict):
+            self.index = state_dict["index"]
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "max_epochs": 1,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+    }
+
+    def train_and_resume(dataloader, resume_step, expected_warning):
+        # Initial training
+        checkpoint_dir = tmp_path / "checkpoints"
+        trainer = Trainer(
+            **trainer_kwargs,
+            callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1),
+        )
+        trainer.fit(BoringModel(), dataloader)
+
+        # Resume
+        trainer = Trainer(**trainer_kwargs, enable_checkpointing=False)
+        resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt"
+        warn_assert = pytest.warns if expected_warning else no_warning_call
+        with warn_assert(PossibleUserWarning, match="resuming from a checkpoint that ended before"):
+            trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from)
+
+    # Resume mid-epoch, no stateful dataloader -> warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=1, expected_warning=True)
+
+    # Resume end-of-epoch, no stateful dataloader -> no warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=3, expected_warning=False)
+
+    # Resume mid-epoch, stateful dataloader -> no warning
+    train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)
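
At the user level, the change means a mid-epoch resume like the sketch below should no longer emit the `PossibleUserWarning`, provided every dataloader passed to `fit` is stateful. The checkpoint path is a placeholder and `ResumableRandomIterable` refers to the illustrative iterable sketched earlier; neither name comes from this commit:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

# Initial run, saving a checkpoint after every optimizer step so that mid-epoch
# checkpoints exist (mirrors the setup in the test above).
trainer = Trainer(max_epochs=1, callbacks=ModelCheckpoint(every_n_train_steps=1, save_top_k=-1))
trainer.fit(BoringModel(), ResumableRandomIterable())

# Resume from a mid-epoch checkpoint: because the dataloader implements
# `state_dict` / `load_state_dict`, the "resuming from a checkpoint that ended
# before the epoch ended" warning is no longer raised.
trainer = Trainer(max_epochs=2)
trainer.fit(BoringModel(), ResumableRandomIterable(), ckpt_path="path/to/epoch=0-step=1.ckpt")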
