 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.progress import _BaseProgress
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter
 
 from tests_pytorch.helpers.runif import RunIf
@@ -882,3 +883,94 @@ def on_validation_start(self):
     )
     trainer.fit(model)
     assert model.ran_assert
+
+
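+# Yields increasing integers and tracks its position in `self.index`, but does not implement the
+# state_dict/load_state_dict protocol, so its progress is not checkpointed and it restarts on resume.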
+class NotStatefulIterable:
+    def __init__(self, start=0):
+        self.index = start
+
+    def __iter__(self):
+        for i in range(self.index, len(self)):
+            self.index = i
+            yield self.index
+
+    def __len__(self):
+        return 10
+
+
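+# Exposes its position via state_dict/load_state_dict so the fit loop can checkpoint and restore it;
+# on restore it resumes at the index after the last one yielded.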
+class StatefulIterable(NotStatefulIterable):
+    def state_dict(self):
+        return {"index": self.index}
+
+    def load_state_dict(self, state_dict):
+        self.index = state_dict["index"] + 1
+
+
+@pytest.mark.parametrize(
+    ("train_dataloader_factory", "has_state", "batches_before", "batches_after"),
+    [
+        # No dataloader
+        (lambda: [], False, [], []),
+        # Single stateful iterable, not wrapped in a CombinedLoader
+        (lambda: StatefulIterable(), True, [0, 1], [2, 3]),
+        # Single non-stateful iterable
+        (lambda: CombinedLoader(NotStatefulIterable()), False, [0, 1], [0, 1]),
+        # Single stateful iterable
+        (lambda: CombinedLoader(StatefulIterable()), True, [0, 1], [2, 3]),
+        # Multiple stateful iterables
+        (lambda: CombinedLoader([StatefulIterable(3), StatefulIterable(1)]), True, [[3, 1], [4, 2]], [[5, 3], [6, 4]]),
+        # Mix of stateful and non-stateful iterables: the stateful one resumes, the others restart
+        (
+            lambda: CombinedLoader([NotStatefulIterable(3), StatefulIterable(1), NotStatefulIterable(2)]),
+            True,
+            [[3, 1, 2], [4, 2, 3]],
+            [[3, 3, 2], [4, 4, 3]],
+        ),
+    ],
+)
+def test_fit_loop_save_and_restore_dataloaders(
+    train_dataloader_factory, has_state, batches_before, batches_after, tmp_path
+):
+    """Test that the CheckpointConnector saves and restores the state of stateful dataloaders."""
+
+    class DummyModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.seen_data = []
+
+        def training_step(self, batch, batch_idx):
+            # Record each batch so the test can compare what was seen before and after resuming
+            self.seen_data.append(batch)
+
+        def train_dataloader(self):
+            return train_dataloader_factory()
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "enable_checkpointing": False,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+        "num_sanity_val_steps": 0,
+    }
+
+    # Train for 2 steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=2)
+    trainer.fit(model)
+    assert model.seen_data == batches_before
+
+    # Save a checkpoint
+    trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
+    checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+    if has_state:
+        assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
+    else:
+        assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]
+
+    # Restore training from step 2 and continue 2 more steps
+    model = DummyModel()
+    trainer = Trainer(**trainer_kwargs, max_steps=4)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    assert model.seen_data == batches_after