Skip to content

Commit 1278308

Browse files
authored
Fix OverflowError when resuming from checkpoint with an iterable dataset (#20624)
* Add test to reproduce OverflowError exception
* Don't increment batch progress in eval loop with inf max batch
* Update changelog
1 parent 7900105 commit 1278308

File tree

3 files changed: +50 −2 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3232

3333
- Fix CSVLogger logging hyperparameters at every write, which increases latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
3434

35+
- Fix OverflowError when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565))
36+
3537

3638
- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
3739

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import math
1415
import os
1516
import shutil
1617
import sys
@@ -268,7 +269,10 @@ def increment_progress_to_evaluation_end(self) -> None:
268269
if self.skip:
269270
return
270271
self.reset()
271-
max_batch = int(max(self.max_batches))
272+
max_batch = max(self.max_batches)
273+
if isinstance(max_batch, float) and math.isinf(max_batch):
274+
return
275+
max_batch = int(max_batch)
272276
if max_batch == -1:
273277
return
274278
self.batch_progress.increment_by(max_batch, True)

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
import yaml
3131
from jsonargparse import ArgumentParser
3232
from torch import optim
33+
from torch.utils.data.dataloader import DataLoader
3334

3435
import lightning.pytorch as pl
3536
from lightning.fabric.utilities.cloud_io import _load as pl_load
3637
from lightning.pytorch import Trainer, seed_everything
3738
from lightning.pytorch.callbacks import ModelCheckpoint
38-
from lightning.pytorch.demos.boring_classes import BoringModel
39+
from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
3940
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
4041
from lightning.pytorch.utilities.exceptions import MisconfigurationException
4142
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected):
16241625
parser.add_argument("--a", type=annot)
16251626
args = parser.parse_args(["--a", val])
16261627
assert args.a == expected
1628+
1629+
1630+
def test_load_with_inf_data_loader(tmp_path):
1631+
"""Test loading from a checkpoint with a dataloader that does not have a length."""
1632+
# Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565
1633+
dataset = RandomIterableDataset(size=32, count=10)
1634+
1635+
class ModelWithIterableDataset(BoringModel):
1636+
def train_dataloader(self) -> DataLoader:
1637+
return DataLoader(dataset)
1638+
1639+
def val_dataloader(self) -> DataLoader:
1640+
return DataLoader(dataset)
1641+
1642+
model = ModelWithIterableDataset()
1643+
with pytest.raises(TypeError):
1644+
len(model.train_dataloader())
1645+
1646+
trainer_kwargs = {
1647+
"default_root_dir": tmp_path,
1648+
"max_epochs": 2,
1649+
"limit_train_batches": 2,
1650+
"limit_val_batches": None,
1651+
"check_val_every_n_epoch": 1,
1652+
"enable_model_summary": False,
1653+
"logger": False,
1654+
}
1655+
mc_kwargs = {
1656+
"save_last": True,
1657+
"every_n_train_steps": 1,
1658+
}
1659+
trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
1660+
trainer.fit(model)
1661+
1662+
checkpoint_path = tmp_path / "checkpoints" / "epoch=1-step=4.ckpt"
1663+
assert checkpoint_path.name in os.listdir(tmp_path / "checkpoints")
1664+
1665+
# Resume from checkpoint and run for more epochs
1666+
trainer_kwargs["max_epochs"] = 4
1667+
trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
1668+
trainer.fit(model, ckpt_path=checkpoint_path)

0 commit comments

Comments
 (0)