
Commit 34a34a0

Enable saving and loading stateful DataLoaders in Trainer (#19361)
1 parent 5d178d0 commit 34a34a0

5 files changed (+192, -1 lines)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381))
 
 
+- Added support for saving and loading stateful training DataLoaders ([#19361](https://github.com/Lightning-AI/lightning/pull/19361))
+
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
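
In user code, the changelog entry amounts to this: if a training iterable exposes `state_dict()`/`load_state_dict()`, the Trainer now saves and restores its position across `fit()` calls. A minimal sketch, modeled on the `StatefulIterable` helper from the tests added in this commit (the usage lines are illustrative, not part of the commit):

# A minimal sketch of the new behavior, modeled on the test added in this commit.
# Any training iterable exposing state_dict()/load_state_dict() is checkpointed.

class StatefulIterable:
    """Yields 0..9 and can save/restore its position."""

    def __init__(self, start=0):
        self.index = start

    def __iter__(self):
        for i in range(self.index, 10):
            self.index = i
            yield self.index

    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"] + 1  # resume at the batch after the saved one

# Illustrative usage with a LightningModule `model` whose training_step accepts these batches:
#   trainer = Trainer(max_steps=2)
#   trainer.fit(model)                          # consumes batches 0, 1
#   trainer.save_checkpoint("checkpoint.ckpt")  # loader state saved with the fit loop state
#   Trainer(max_steps=4).fit(model, ckpt_path="checkpoint.ckpt")  # resumes with batches 2, 3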

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 22 additions & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -94,6 +94,7 @@ def __init__(
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
+        self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
 
@@ -255,6 +256,8 @@ def setup_data(self) -> None:
 
         combined_loader.limits = limits
 
+        self._load_combined_loader_states()
+
         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
         iter(self._data_fetcher)  # creates the iterator inside the fetcher
@@ -409,9 +412,27 @@ def teardown(self) -> None:
         self._data_fetcher = None
         self.epoch_loop.teardown()
 
+    @override
+    def on_save_checkpoint(self) -> Dict:
+        state_dict = super().on_save_checkpoint()
+        if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
+            state_dict["combined_loader"] = loader_states
+        return state_dict
+
+    @override
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
+        self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
+        super().on_load_checkpoint(state_dict)
+
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()
 
     def _iteration_based_training(self) -> bool:
         return self.trainer.max_steps != -1
+
+    def _load_combined_loader_states(self) -> None:
+        if not self.restarting or not self._combined_loader_states_to_load or self._combined_loader is None:
+            return
+        self._combined_loader._load_state_dicts(self._combined_loader_states_to_load)
+        self._combined_loader_states_to_load = []  # release memory
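
Net effect of the three hooks: `on_save_checkpoint` stores the loader states under the fit loop's entry in the checkpoint, `on_load_checkpoint` stashes them on the loop, and `setup_data` applies them once the dataloaders exist again on restart. A short sketch of where the states land in the saved file (the key path is taken from the test assertions further down; nothing else is assumed):

# Sketch: where the saved loader states live in a checkpoint. The key path below
# mirrors the assertions in tests/tests_pytorch/loops/test_loops.py.
import torch

checkpoint = torch.load("checkpoint.ckpt")
loader_states = checkpoint["loops"]["fit_loop"]["state_dict"].get("combined_loader", [])
# One state dict per stateful iterable, in the CombinedLoader's flattened order.
# The key is absent when no training iterable is stateful.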

src/lightning/pytorch/utilities/combined_loader.py

Lines changed: 19 additions & 0 deletions

@@ -19,6 +19,7 @@
 from typing_extensions import Self, TypedDict, override
 
 from lightning.fabric.utilities.data import sized_len
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten
 
 _ITERATOR_RETURN = Tuple[Any, int, int]  # batch, batch_idx, dataloader_idx
@@ -374,6 +375,24 @@ def _dataset_length(self) -> int:
         fn = _SUPPORTED_MODES[self._mode]["fn"]
         return fn(lengths)
 
+    def _state_dicts(self) -> List[Dict[str, Any]]:
+        """Returns the list of state dicts for iterables in `self.flattened` that are stateful."""
+        return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)]
+
+    def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None:
+        """Loads the state dicts for iterables in `self.flattened` that are stateful."""
+        if not states:
+            return
+        stateful_loaders = [loader for loader in self.flattened if isinstance(loader, _Stateful)]
+        if len(stateful_loaders) != len(states):
+            raise RuntimeError(
+                f"The CombinedLoader has {len(stateful_loaders)} stateful loaders, but found {len(states)} states"
+                " in the checkpoint. Please make sure you define the same dataloaders that were used when saving"
+                " the checkpoint."
+            )
+        for loader, state_dict in zip(stateful_loaders, states):
+            loader.load_state_dict(state_dict)
+
 
 def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
     if hasattr(dataloader, "_iterator"):
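
Both helpers gate on `_Stateful`, which is a structural check rather than a required base class. A minimal sketch, assuming `_Stateful` is a runtime-checkable protocol requiring `state_dict` and `load_state_dict` (consistent with the `isinstance()` usage above):

# Sketch, assuming _Stateful is a runtime-checkable Protocol: any object with
# state_dict()/load_state_dict() passes the isinstance check, no subclassing needed.
from lightning.fabric.utilities.types import _Stateful

class PlainIterable:
    def __iter__(self):
        return iter(range(3))

class CheckpointableIterable(PlainIterable):
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

assert not isinstance(PlainIterable(), _Stateful)        # no state methods -> skipped when saving
assert isinstance(CheckpointableIterable(), _Stateful)   # duck-typed, included in _state_dicts()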

tests/tests_pytorch/loops/test_loops.py

Lines changed: 92 additions & 0 deletions

@@ -24,6 +24,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter
 
 from tests_pytorch.helpers.runif import RunIf
@@ -882,3 +883,94 @@ def on_validation_start(self):
     )
     trainer.fit(model)
     assert model.ran_assert
+
+
+class NotStatefulIterable:
+    def __init__(self, start=0):
+        self.index = start
+
+    def __iter__(self):
+        for i in range(self.index, len(self)):
+            self.index = i
+            yield self.index
+
+    def __len__(self):
+        return 10
+
+
+class StatefulIterable(NotStatefulIterable):
+    def state_dict(self):
+        return {"index": self.index}
+
+    def load_state_dict(self, state_dict):
+        self.index = state_dict["index"] + 1
+
+
+@pytest.mark.parametrize(
+    ("train_dataloader_factory", "has_state", "batches_before", "batches_after"),
+    [
+        # No dataloader
+        (lambda: [], False, [], []),
+        # Single stateful DataLoader
+        (lambda: StatefulIterable(), True, [0, 1], [2, 3]),
+        # Single, not stateful DataLoader
+        (lambda: CombinedLoader(NotStatefulIterable()), False, [0, 1], [0, 1]),
+        # Single stateful DataLoader
+        (lambda: CombinedLoader(StatefulIterable()), True, [0, 1], [2, 3]),
+        # Multiple stateful DataLoaders
+        (lambda: CombinedLoader([StatefulIterable(3), StatefulIterable(1)]), True, [[3, 1], [4, 2]], [[5, 3], [6, 4]]),
+        # Mix of stateful and not stateful DataLoaders
+        (
+            lambda: CombinedLoader([NotStatefulIterable(3), StatefulIterable(1), NotStatefulIterable(2)]),
+            True,
+            [[3, 1, 2], [4, 2, 3]],
+            [[3, 3, 2], [4, 4, 3]],
+        ),
+    ],
+)
+def test_fit_loop_save_and_restore_dataloaders(
+    train_dataloader_factory, has_state, batches_before, batches_after, tmp_path
+):
+    """Test that the CheckpointConnector saves the state of stateful dataloaders."""
+
+    class DummyModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.seen_data = []
+
+        def training_step(self, batch, batch_idx):
+            self.seen_data.append(batch)
+            print(batch)
+
+        def train_dataloader(self):
+            return train_dataloader_factory()
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "enable_checkpointing": False,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+        "num_sanity_val_steps": 0,
+    }
+
+    # Train for 2 steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=2)
+    trainer.fit(model)
+    assert model.seen_data == batches_before
+
+    # Save a checkpoint
+    trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
+    checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+    if has_state:
+        assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
+    else:
+        assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]
+
+    # Restore training from step 2 and continue 2 more steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=4)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    assert model.seen_data == batches_after
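
Note the `+ 1` in `StatefulIterable.load_state_dict`: the saved `index` is the last batch that was consumed, so the restored iterable resumes at the following one. That is why the stateful cases that saw `[0, 1]` before the checkpoint continue with `[2, 3]` after restoring, while the not-stateful variant starts over at `[0, 1]`.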

tests/tests_pytorch/utilities/test_combined_loader.py

Lines changed: 56 additions & 0 deletions

@@ -14,9 +14,11 @@
 import math
 import pickle
 from typing import Any, NamedTuple, Sequence, get_args
+from unittest.mock import Mock
 
 import pytest
 import torch
+from lightning.fabric.utilities.types import _Stateful
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.utilities.combined_loader import (
@@ -602,3 +604,57 @@ def test_combined_loader_can_be_pickled():
 
     # no error
     pickle.dumps(cl)
+
+
+def test_state_dicts():
+    state1, state2, state3 = Mock(), Mock(), Mock()
+    stateful1 = Mock(spec=_Stateful, state_dict=Mock(return_value=state1))
+    stateful2 = Mock(spec=_Stateful, state_dict=Mock(return_value=state2))
+    stateful3 = Mock(spec=_Stateful, state_dict=Mock(return_value=state3))
+
+    cl = CombinedLoader([])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([range(2)])
+    assert cl._state_dicts() == []
+    cl = CombinedLoader([stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1])
+    assert cl._state_dicts() == [state1]
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    assert cl._state_dicts() == [state1, state2]
+    cl = CombinedLoader({"a": [range(2), stateful1], "b": [stateful2], "c": stateful3})
+    assert cl._state_dicts() == [state1, state2, state3]
+
+
+def test_load_state_dicts():
+    stateful1 = Mock(spec=_Stateful)
+    stateful2 = Mock(spec=_Stateful)
+    state1 = Mock()
+    state2 = Mock()
+
+    # 0 stateful loaders, 1 state to load
+    cl = CombinedLoader([range(2), range(3)])
+    with pytest.raises(RuntimeError, match="has 0 stateful loaders, but found 1 states"):
+        cl._load_state_dicts([{"state": 0}])
+
+    # 1 stateful loader, 0 states to load
+    cl = CombinedLoader([stateful1, range(3)])
+    cl._load_state_dicts([])
+    stateful1.load_state_dict.assert_not_called()
+
+    # 1 stateful loader, 1 state to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    cl._load_state_dicts([state1])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful1.reset_mock()
+
+    # 1 stateful loader, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3)])
+    with pytest.raises(RuntimeError, match="has 1 stateful loaders, but found 2 states"):
+        cl._load_state_dicts([state1, state2])
+
+    # 2 stateful loaders, 2 states to load
+    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
+    cl._load_state_dicts([state1, state2])
+    stateful1.load_state_dict.assert_called_with(state1)
+    stateful2.load_state_dict.assert_called_with(state2)
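
A brief note on the mocks, for readers unfamiliar with spec'd mocks (this is standard `unittest.mock` behavior, not something added by the commit):

# Sketch: why Mock(spec=_Stateful) passes the isinstance(loader, _Stateful) check
# in _state_dicts()/_load_state_dicts(). A spec'd Mock exposes the spec's methods
# and reports the spec as its class, so the structural protocol check succeeds.
from unittest.mock import Mock

from lightning.fabric.utilities.types import _Stateful

m = Mock(spec=_Stateful)
assert isinstance(m, _Stateful)
assert hasattr(m, "state_dict") and hasattr(m, "load_state_dict")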
