diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 8bc8e45989f77..03ec83f189f23 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
+- Fix OverflowError when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565))
+
 
 ## [2.5.0] - 2024-12-19
 
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index d007466ee3b1c..7f033dbd8e2c2 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import os
 import shutil
 import sys
@@ -268,7 +269,10 @@ def increment_progress_to_evaluation_end(self) -> None:
         if self.skip:
             return
         self.reset()
-        max_batch = int(max(self.max_batches))
+        max_batch = max(self.max_batches)
+        if isinstance(max_batch, float) and math.isinf(max_batch):
+            return
+        max_batch = int(max_batch)
         if max_batch == -1:
             return
         self.batch_progress.increment_by(max_batch, True)
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 2351cc7548f79..1907a5fb35799 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -30,12 +30,13 @@
 import yaml
 from jsonargparse import ArgumentParser
 from torch import optim
+from torch.utils.data.dataloader import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected):
     parser.add_argument("--a", type=annot)
     args = parser.parse_args(["--a", val])
     assert args.a == expected
+
+
+def test_load_with_inf_data_loader(tmp_path):
+    """Test loading from a checkpoint with a dataloader that does not have a length."""
+    # Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565
+    dataset = RandomIterableDataset(size=32, count=10)
+
+    class ModelWithIterableDataset(BoringModel):
+        def train_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+        def val_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+    model = ModelWithIterableDataset()
+    with pytest.raises(TypeError):
+        len(model.train_dataloader())
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "max_epochs": 2,
+        "limit_train_batches": 2,
+        "limit_val_batches": None,
+        "check_val_every_n_epoch": 1,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "save_last": True,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+
+    checkpoint_path = tmp_path / "checkpoints" / "epoch=1-step=4.ckpt"
+    assert checkpoint_path.name in os.listdir(tmp_path / "checkpoints")
+
+    # Resume from checkpoint and run for more epochs
+    trainer_kwargs["max_epochs"] = 4
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=checkpoint_path)
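
Note on the root cause (illustrative sketch, not part of the patch): for an iterable dataset without __len__, the evaluation loop tracks the batch count as float("inf"), and int(float("inf")) raises OverflowError when progress is fast-forwarded on resume. The guard added above skips the increment for an infinite count. A standalone Python sketch of the same logic, using a hypothetical max_batches list:

import math

max_batches = [float("inf")]  # hypothetical value tracked for a dataloader with no length

try:
    int(max(max_batches))  # pre-patch path
except OverflowError as err:
    print(f"OverflowError: {err}")  # "cannot convert float infinity to integer"

# patched path: bail out before the int() cast when the count is infinite
max_batch = max(max_batches)
if isinstance(max_batch, float) and math.isinf(max_batch):
    print("infinite batch count; skip incrementing progress")
else:
    print(f"increment progress by {int(max_batch)}")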