Skip to content

Commit cbb9fb5

Browse files
committed
Fix validation loop handling on restart
1 parent 0750b2e commit cbb9fb5

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def setup_data(self) -> None:
201201
def restarting_on_evaluation_end(self) -> bool:
    """Return whether this restart resumes from a checkpoint saved at the end of evaluation.

    The signature of such a checkpoint is that ``ready``, ``started`` and ``processed``
    all agree while ``completed`` lags by exactly one — the checkpoint was written after
    the last batch was processed but before it was marked completed.
    """
    # NOTE: every comparison reads ``self.batch_progress`` — the earlier
    # ``self.batch.progress`` spelling was a typo (AttributeError on restart).
    total = self.batch_progress.total
    return (
        self.restarting
        and total.started == total.ready
        and total.processed == total.started
        and total.completed == total.processed - 1
    )
@@ -245,6 +245,14 @@ def reset(self) -> None:
245245
data_fetcher._stop_profiler = self._on_after_fetch
246246
self._data_fetcher = data_fetcher
247247

248+
def increment_progress_to_evaluation_end(self) -> None:
    """Fast-forward the batch-progress counters to the state they would hold at the end
    of a full evaluation run, without actually running any batches.

    Called on restart when training resumes from a checkpoint taken right before a
    validation that already ran, so the counters match an uninterrupted run.
    """
    self.setup_data()
    # Nothing to advance if evaluation would be skipped entirely.
    if self.skip:
        return
    self.reset()
    # NOTE(review): ``max`` over the per-dataloader batch counts — presumably the
    # progress tracker follows the longest dataloader; confirm for multi-dataloader runs.
    max_batch = max(self.max_batches)
    self.batch_progress.increment_by(max_batch, is_last_batch=True)
255+
248256
def on_run_start(self) -> None:
249257
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
250258
hooks."""

src/lightning/pytorch/loops/progress.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def reset(self) -> None:
5959
self.ready = 0
6060
self.completed = 0
6161

62+
@override
6263
def reset_on_restart(self) -> None:
6364
"""Reset the progress on restart.
6465
@@ -68,6 +69,11 @@ def reset_on_restart(self) -> None:
6869
"""
6970
self.ready = self.completed
7071

72+
@override
73+
def increment_by(self, n) -> None:
74+
self.ready += n
75+
self.completed += n
76+
7177

7278
@dataclass
7379
class _StartedTracker(_ReadyCompletedTracker):
@@ -94,6 +100,11 @@ def reset_on_restart(self) -> None:
94100
super().reset_on_restart()
95101
self.started = self.completed
96102

103+
@override
def increment_by(self, n: int) -> None:
    """Advance all inherited counters plus ``started`` by ``n``."""
    super().increment_by(n)
    self.started += n
107+
97108

98109
@dataclass
99110
class _ProcessedTracker(_StartedTracker):
@@ -121,6 +132,11 @@ def reset_on_restart(self) -> None:
121132
super().reset_on_restart()
122133
self.processed = self.completed
123134

135+
@override
def increment_by(self, n: int) -> None:
    """Advance all inherited counters plus ``processed`` by ``n``."""
    super().increment_by(n)
    self.processed += n
139+
124140

125141
@dataclass
126142
class _Progress(_BaseProgress):
@@ -175,6 +191,11 @@ def reset_on_run(self) -> None:
175191
def reset_on_restart(self) -> None:
176192
self.current.reset_on_restart()
177193

194+
@override
195+
def increment_by(self, n) -> None:
196+
self.total.increment_by(n)
197+
self.current.increment_by(n)
198+
178199
@override
179200
def load_state_dict(self, state_dict: dict) -> None:
180201
self.total.load_state_dict(state_dict["total"])
@@ -206,6 +227,10 @@ def reset_on_run(self) -> None:
206227
super().reset_on_run()
207228
self.is_last_batch = False
208229

230+
def increment_by(self, n: int, is_last_batch: bool = False) -> None:
    """Advance all counters by ``n`` and record whether the last batch was reached.

    No ``@override`` here: the signature widens the base method with ``is_last_batch``.
    """
    super().increment_by(n)
    self.is_last_batch = is_last_batch
233+
209234
@override
210235
def load_state_dict(self, state_dict: dict) -> None:
211236
super().load_state_dict(state_dict)

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
217217
218218
"""
219219
if self.restarting and self._should_check_val_fx(data_fetcher):
220-
# skip training and run validation in `on_advance_end`
221-
return
220+
# fast forward progress counters to end of validation
221+
self.val_loop.increment_progress_to_evaluation_end()
222+
222223
# we are going to train first so the val loop does not need to restart
223224
self.val_loop.restarting = False
224225

tests/tests_pytorch/loops/test_loops.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,71 @@ def test_restart_parity(tmp_path):
730730
assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
731731

732732

733+
def test_restart_parity_with_val(tmp_path):
    """Resuming mid-epoch with interleaved validation must reproduce the uninterrupted run:
    same final loss, loop state, LR-scheduler state, epoch/step counters, and weights."""

    def make_trainer(checkpoint_callback):
        # Identical configuration for the initial and the resumed run: 4 train batches
        # per epoch, validation every 2 steps, checkpoint every 2 steps.
        return Trainer(
            default_root_dir=tmp_path,
            limit_train_batches=4,
            max_epochs=4,
            callbacks=[checkpoint_callback],
            logger=False,
            enable_model_summary=False,
            enable_progress_bar=False,
            limit_val_batches=4,
            val_check_interval=2,
        )

    def assert_checkpoints_match(name):
        # The resumed run saves next to the original checkpoint with a ``-v1`` suffix.
        ckpt = torch.load(str(tmp_path / f"{name}.ckpt"), weights_only=True)
        ckpt_v1 = torch.load(str(tmp_path / f"{name}-v1.ckpt"), weights_only=True)
        assert compare_state_dicts(ckpt["loops"], ckpt_v1["loops"]) == {}
        assert compare_state_dicts(ckpt["lr_schedulers"][0], ckpt_v1["lr_schedulers"][0]) == {}
        assert ckpt["epoch"] == ckpt_v1["epoch"]
        assert ckpt["global_step"] == ckpt_v1["global_step"]
        assert compare_state_dicts(ckpt["state_dict"], ckpt_v1["state_dict"]) == {}

    model = PredictableBoringModel()
    # One callback instance shared by both runs, as in the original test.
    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        every_n_train_steps=2,
        save_top_k=-1,
    )

    make_trainer(checkpoint_callback).fit(model)
    loss = make_trainer  # placeholder removed below
    loss = model.last_loss

    make_trainer(checkpoint_callback).fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
    loss_v1 = model.last_loss

    assert abs(loss - loss_v1) < 1e-8

    # End of epoch 0, middle of epoch 1, end of epoch 1.
    assert_checkpoints_match("epoch=0-step=4")
    assert_checkpoints_match("epoch=1-step=6")
    assert_checkpoints_match("epoch=1-step=8")
796+
797+
733798
@pytest.mark.parametrize(
734799
("train_datasets", "val_datasets"),
735800
[([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])],

0 commit comments

Comments
 (0)