@@ -30,12 +30,13 @@
 import yaml
 from jsonargparse import ArgumentParser
 from torch import optim
+from torch.utils.data.dataloader import DataLoader

 import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected):
     parser.add_argument("--a", type=annot)
     args = parser.parse_args(["--a", val])
     assert args.a == expected
+
+
+def test_load_with_inf_data_loader(tmp_path):
+    """Test loading from a checkpoint with a dataloader that does not have a length."""
+    # Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565
+    dataset = RandomIterableDataset(size=32, count=10)
+
+    class ModelWithIterableDataset(BoringModel):
+        def train_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+        def val_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+    model = ModelWithIterableDataset()
+    with pytest.raises(TypeError):
+        len(model.train_dataloader())
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "max_epochs": 2,
+        "limit_train_batches": 2,
+        "limit_val_batches": None,
+        "check_val_every_n_epoch": 1,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "save_last": True,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+
+    checkpoint_path = tmp_path / "checkpoints" / "epoch=1-step=4.ckpt"
+    assert checkpoint_path.name in os.listdir(tmp_path / "checkpoints")
+
+    # Resume from checkpoint and run for more epochs
+    trainer_kwargs["max_epochs"] = 4
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=checkpoint_path)
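Note on the pytest.raises(TypeError) assertion above: DataLoader.__len__ delegates to the underlying dataset, and an IterableDataset that does not implement __len__ therefore has no length, which is exactly the property this regression test exercises across a checkpoint save and restore. The asserted filename epoch=1-step=4.ckpt follows ModelCheckpoint's default epoch={epoch}-step={step} naming: max_epochs=2 with limit_train_batches=2 finishes training at global step 4. Below is a minimal sketch of the length-less dataloader behavior, independent of Lightning's demo classes; the CountedIterableDataset name is hypothetical, not part of this PR.

import torch
from torch.utils.data import DataLoader, IterableDataset


class CountedIterableDataset(IterableDataset):
    # Hypothetical stand-in for RandomIterableDataset: yields `count` random
    # vectors of width `size` and deliberately defines no __len__.
    def __init__(self, size: int, count: int) -> None:
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


loader = DataLoader(CountedIterableDataset(size=32, count=10))
try:
    len(loader)  # DataLoader defers to the dataset, which has no __len__
except TypeError:
    print("no length available, matching the pytest.raises(TypeError) check")

This mirrors the contract of RandomIterableDataset closely enough to show why the resume logic cannot rely on len() of the train or val dataloader when restoring from a checkpoint.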