src/lightning/pytorch/CHANGELOG.md (2 additions, 0 deletions)

@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fix CSVLogger logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
+- Fix OverflowError when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565))
+
 
 ## [2.5.0] - 2024-12-19

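The root cause is Python-level: `int()` cannot convert an infinite float, and Lightning tracks the batch count of a length-less dataloader as `float("inf")`. A minimal sketch of the pre-fix failure, assuming nothing beyond the standard library:

```python
# Minimal sketch of the failure mode fixed by this PR: Lightning stores
# the batch count of a length-less dataloader as float("inf"), and the
# old code converted it to int unconditionally.
max_batches = [float("inf")]

try:
    max_batch = int(max(max_batches))  # mirrors the removed line below
except OverflowError as err:
    print(err)  # cannot convert float infinity to integer
```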
src/lightning/pytorch/loops/evaluation_loop.py (5 additions, 1 deletion)

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import os
 import shutil
 import sys
@@ -268,7 +269,10 @@ def increment_progress_to_evaluation_end(self) -> None:
         if self.skip:
             return
         self.reset()
-        max_batch = int(max(self.max_batches))
+        max_batch = max(self.max_batches)
+        if isinstance(max_batch, float) and math.isinf(max_batch):
+            return
+        max_batch = int(max_batch)
         if max_batch == -1:
             return
         self.batch_progress.increment_by(max_batch, True)
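Extracted as a standalone function, the patched control flow looks like the sketch below; the helper name and the `None` returns are illustrative, not Lightning API.

```python
import math

def resolved_max_batch(max_batches):
    # Illustrative helper mirroring the patched logic above;
    # None stands in for the method's early `return`.
    max_batch = max(max_batches)
    if isinstance(max_batch, float) and math.isinf(max_batch):
        return None  # unbounded loader: skip the progress increment
    max_batch = int(max_batch)
    if max_batch == -1:
        return None  # -1 is the other skip sentinel in the loop above
    return max_batch

assert resolved_max_batch([float("inf")]) is None  # previously raised OverflowError
assert resolved_max_batch([-1]) is None
assert resolved_max_batch([5]) == 5
```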
tests/tests_pytorch/checkpointing/test_model_checkpoint.py (43 additions, 1 deletion)

@@ -30,12 +30,13 @@
 import yaml
 from jsonargparse import ArgumentParser
 from torch import optim
+from torch.utils.data.dataloader import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected):
     parser.add_argument("--a", type=annot)
     args = parser.parse_args(["--a", val])
     assert args.a == expected
+
+
+def test_load_with_inf_data_loader(tmp_path):
+    """Test loading from a checkpoint with a dataloader that does not have a length."""
+    # Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565
+    dataset = RandomIterableDataset(size=32, count=10)
+
+    class ModelWithIterableDataset(BoringModel):
+        def train_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+        def val_dataloader(self) -> DataLoader:
+            return DataLoader(dataset)
+
+    model = ModelWithIterableDataset()
+    with pytest.raises(TypeError):
+        len(model.train_dataloader())
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "max_epochs": 2,
+        "limit_train_batches": 2,
+        "limit_val_batches": None,
+        "check_val_every_n_epoch": 1,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "save_last": True,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+
+    checkpoint_path = tmp_path / "checkpoints" / "epoch=1-step=4.ckpt"
+    assert checkpoint_path.name in os.listdir(tmp_path / "checkpoints")
+
+    # Resume from checkpoint and run for more epochs
+    trainer_kwargs["max_epochs"] = 4
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=checkpoint_path)
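The test's `pytest.raises(TypeError)` assertion holds because `DataLoader.__len__` defers to the dataset, and an `IterableDataset` that defines no `__len__` therefore has no reportable length; that missing length is what makes Lightning fall back to an infinite batch count. A self-contained sketch with stock PyTorch, where `StreamDataset` is an illustrative stand-in for `RandomIterableDataset`:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    """Illustrative stand-in for RandomIterableDataset: yields a fixed
    number of random samples but deliberately defines no __len__."""

    def __iter__(self):
        for _ in range(10):
            yield torch.randn(32)

loader = DataLoader(StreamDataset())
try:
    len(loader)  # DataLoader asks the dataset for its length
except TypeError as err:
    print(err)  # the dataset has no len()
```

Because the loader reports no length, `limit_train_batches=2` is what bounds each epoch in the test, so two epochs of two batches each land the checkpoint at `epoch=1-step=4.ckpt`.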