Skip to content

Commit caeea6e

Browse files
committed
Add test to reproduce OverflowError exception
1 parent 1f5add3 commit caeea6e

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
import yaml
3131
from jsonargparse import ArgumentParser
3232
from torch import optim
33+
from torch.utils.data.dataloader import DataLoader
3334

3435
import lightning.pytorch as pl
3536
from lightning.fabric.utilities.cloud_io import _load as pl_load
3637
from lightning.pytorch import Trainer, seed_everything
3738
from lightning.pytorch.callbacks import ModelCheckpoint
38-
from lightning.pytorch.demos.boring_classes import BoringModel
39+
from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
3940
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
4041
from lightning.pytorch.utilities.exceptions import MisconfigurationException
4142
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected):
16241625
parser.add_argument("--a", type=annot)
16251626
args = parser.parse_args(["--a", val])
16261627
assert args.a == expected
1628+
1629+
1630+
def test_load_with_inf_data_loader(tmp_path):
    """Resume training from a checkpoint when the dataloaders do not define a length.

    Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565

    """
    iterable_data = RandomIterableDataset(size=32, count=10)

    class ModelWithIterableDataset(BoringModel):
        def train_dataloader(self) -> DataLoader:
            return DataLoader(iterable_data)

        def val_dataloader(self) -> DataLoader:
            return DataLoader(iterable_data)

    def make_trainer(max_epochs):
        # Every run gets its own Trainer and its own ModelCheckpoint callback;
        # only the epoch budget differs between the initial and the resumed run.
        checkpoint_callback = ModelCheckpoint(save_last=True, every_n_train_steps=1)
        return Trainer(
            default_root_dir=tmp_path,
            max_epochs=max_epochs,
            limit_train_batches=2,
            limit_val_batches=None,
            check_val_every_n_epoch=1,
            enable_model_summary=False,
            logger=False,
            callbacks=checkpoint_callback,
        )

    model = ModelWithIterableDataset()

    # Precondition of the scenario: the training dataloader has no usable ``len``.
    with pytest.raises(TypeError):
        len(model.train_dataloader())

    # Initial run: 2 epochs of 2 training batches each, checkpointing every step.
    make_trainer(max_epochs=2).fit(model)

    checkpoint_dir = tmp_path / "checkpoints"
    checkpoint_path = checkpoint_dir / "epoch=1-step=4.ckpt"
    assert checkpoint_path.name in os.listdir(checkpoint_dir)

    # Resume from checkpoint and run for more epochs
    make_trainer(max_epochs=4).fit(model, ckpt_path=checkpoint_path)

0 commit comments

Comments
 (0)