Skip to content

Commit 0cd837f

Browse files
Add a migration for the dataloader loops (#17125)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4ce1b65 commit 0cd837f

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

src/lightning/pytorch/utilities/migration/migration.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
5050
_drop_apex_amp_state,
5151
_migrate_loop_structure_after_tbptt_removal,
5252
_migrate_loop_structure_after_optimizer_loop_removal,
53+
_migrate_loop_structure_after_dataloader_loop_removal,
5354
],
5455
}
5556

@@ -236,7 +237,8 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE
236237
"""
237238
if "loops" not in checkpoint:
238239
return checkpoint
239-
240+
if "fit_loop" not in checkpoint["loops"]:
241+
return checkpoint
240242
fit_loop = checkpoint["loops"]["fit_loop"]
241243

242244
# remap `x.batch_loop.y` to `x.y`
@@ -273,8 +275,10 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
273275
"""
274276
if "loops" not in checkpoint:
275277
return checkpoint
276-
278+
if "fit_loop" not in checkpoint["loops"]:
279+
return checkpoint
277280
fit_loop = checkpoint["loops"]["fit_loop"]
281+
278282
# optimizer_position is no longer used
279283
if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
280284
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
@@ -291,3 +295,25 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
291295
"epoch_loop.manual_loop.optim_step_progress"
292296
)
293297
return checkpoint
298+
299+
300+
def _migrate_loop_structure_after_dataloader_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
301+
"""The dataloader loops (``_DataLoaderLoop``, ``_PredictionLoop`, and ``_EvaluationLoop``) were flattened into
302+
the ``_EvaluationEpochLoop`` (now ``_EvaluationLoop``) and ``_PredictionEpochLoop`` (now ``_PredictionLoop``).
303+
304+
Version: 2.0.0
305+
Commit: ec4f592ecfe238edd83185f6c6905fb1e2406d61
306+
PR: #16726
307+
"""
308+
if "loops" not in checkpoint:
309+
return checkpoint
310+
loops = checkpoint["loops"]
311+
for loop_key in ("predict_loop", "validate_loop", "test_loop"):
312+
if loop_key not in loops:
313+
continue
314+
loop = loops[loop_key]
315+
loop.pop("dataloader_progress", None) # no longer used
316+
epoch_loop_key = "epoch_loop."
317+
epoch_loop_dict = {k[len(epoch_loop_key) :]: loop.pop(k) for k in list(loop) if k.startswith(epoch_loop_key)}
318+
loop.update(epoch_loop_dict)
319+
return checkpoint

tests/tests_pytorch/utilities/migration/test_migration.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,52 @@ def test_migrate_loop_structure_after_optimizer_loop_removal():
227227
"epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
228228
}
229229
}
230+
231+
232+
def test_migrate_loop_structure_after_dataloader_loop_removal():
233+
"""Test the loop state migration after the dataloader loops were removed in 2.0.0."""
234+
old_dataloader_loop_state_dict = {
235+
"state_dict": {},
236+
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
237+
"epoch_loop.state_dict": {},
238+
"epoch_loop.batch_progress": {
239+
"total": {"ready": 123, "started": 0, "processed": 0, "completed": 0},
240+
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
241+
"is_last_batch": False,
242+
},
243+
}
244+
old_checkpoint = {
245+
"loops": {
246+
"predict_loop": old_dataloader_loop_state_dict,
247+
"validate_loop": dict(old_dataloader_loop_state_dict), # copy
248+
"test_loop": dict(old_dataloader_loop_state_dict), # copy
249+
}
250+
}
251+
_set_version(old_checkpoint, "1.9.0") # pretend a checkpoint prior to 2.0.0
252+
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
253+
assert updated_checkpoint["loops"] == {
254+
"predict_loop": {
255+
"batch_progress": {
256+
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
257+
"is_last_batch": False,
258+
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
259+
},
260+
"state_dict": {},
261+
},
262+
"test_loop": {
263+
"batch_progress": {
264+
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
265+
"is_last_batch": False,
266+
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
267+
},
268+
"state_dict": {},
269+
},
270+
"validate_loop": {
271+
"batch_progress": {
272+
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
273+
"is_last_batch": False,
274+
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
275+
},
276+
"state_dict": {},
277+
},
278+
}

0 commit comments

Comments
 (0)