1515from unittest .mock import Mock , patch
1616
1717import pytest
18+ import torch
19+ from lightning .fabric .utilities .warnings import PossibleUserWarning
20+ from lightning .pytorch .callbacks import ModelCheckpoint
1821from lightning .pytorch .demos .boring_classes import BoringModel
1922from lightning .pytorch .trainer .trainer import Trainer
23+ from lightning_utilities .test .warning import no_warning_call
2024
2125
22- def test_no_val_on_train_epoch_loop_restart (tmpdir ):
26+ def test_no_val_on_train_epoch_loop_restart (tmp_path ):
2327 """Test that training validation loop doesn't get triggered at the beginning of a restart."""
2428 trainer_kwargs = {
2529 "max_epochs" : 1 ,
@@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
3135 trainer = Trainer (** trainer_kwargs )
3236 model = BoringModel ()
3337 trainer .fit (model )
34- ckpt_path = str (tmpdir / "last.ckpt" )
38+ ckpt_path = str (tmp_path / "last.ckpt" )
3539 trainer .save_checkpoint (ckpt_path )
3640
3741 trainer_kwargs ["max_epochs" ] = 2
@@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs):
157161 model = MyModel ()
158162 trainer .fit (model )
159163 assert model .last_batch_idx == 3
164+
165+
def test_resume_mid_epoch_warning(tmp_path):
    """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful."""

    class PlainIterable:
        # No state_dict/load_state_dict: the loops cannot restore its position on resume.
        def __init__(self):
            self.index = 0

        def __len__(self):
            return 3

        def __iter__(self):
            # Emit one (2, 32) tensor per remaining index, starting at the current position.
            yield from (torch.ones(2, 32) * i for i in range(self.index, len(self)))

    class CheckpointableIterable(PlainIterable):
        # Same data stream, but exposes its position so resume can fast-forward it.
        def state_dict(self):
            return {"index": self.index}

        def load_state_dict(self, state_dict):
            self.index = state_dict["index"]

    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "accelerator": "cpu",
        "max_epochs": 1,
        "enable_model_summary": False,
        "enable_progress_bar": False,
        "logger": False,
    }

    def train_and_resume(dataloader, resume_step, expected_warning):
        # First run: checkpoint after every optimizer step so any step can be resumed from.
        checkpoint_dir = tmp_path / "checkpoints"
        first_trainer = Trainer(
            **trainer_kwargs,
            callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1),
        )
        first_trainer.fit(BoringModel(), dataloader)

        # Second run: resume from the requested step and assert the warning behavior.
        second_trainer = Trainer(**trainer_kwargs, enable_checkpointing=False)
        resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt"
        expectation = pytest.warns if expected_warning else no_warning_call
        with expectation(PossibleUserWarning, match="resuming from a checkpoint that ended before"):
            second_trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from)

    # Resume mid-epoch, no stateful dataloader -> warning
    train_and_resume(dataloader=PlainIterable(), resume_step=1, expected_warning=True)

    # Resume end-of-epoch, no stateful dataloader -> no warning
    train_and_resume(dataloader=PlainIterable(), resume_step=3, expected_warning=False)

    # Resume mid-epoch, stateful dataloader -> no warning
    train_and_resume(dataloader=CheckpointableIterable(), resume_step=1, expected_warning=False)
0 commit comments