diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3558fe3cadc66..938dca71525ec 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -6,19 +6,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). --- -## [Unreleased] - YYYY-MM-DD -### Added - -- - -### Changed - -- +### Fixed -### Removed +- Ensured ``LightningDataModule`` always exposes ``allow_zero_length_dataloader_with_multiple_devices`` so Trainer zero-length checks don't raise ``AttributeError`` when subclasses skip ``super().__init__`` ([#21358](https://github.com/Lightning-AI/pytorch-lightning/issues/21358)) -- +--- ## [2.6.0] - 2025-11-28 diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 07ec02ef87bd8..66027d8c343d1 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -76,6 +76,9 @@ def teardown(self): """ name: Optional[str] = None + # Fallback for subclasses that don't call ``super().__init__``. The attribute normally gets initialized in + # ``DataHooks.__init__`` but Trainer loops expect it to exist regardless.
+ allow_zero_length_dataloader_with_multiple_devices: bool = False CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 6d97edb241fe5..47706ae66b735 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -21,6 +21,7 @@ import pytest import torch +from torch.utils.data import DataLoader from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -29,6 +30,7 @@ BoringDataModuleNoLen, BoringModel, IterableBoringDataModule, + RandomDataset, ) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn @@ -274,6 +276,29 @@ def train_dataloader(self): trainer.fit(model, dm) +def test_datamodule_allow_zero_length_attr_without_super(tmp_path): + class DataModuleWithoutSuper(LightningDataModule): + def __init__(self): + self.data = RandomDataset(32, 4) + + def val_dataloader(self): + return DataLoader(self.data, batch_size=2) + + dm = DataModuleWithoutSuper() + assert dm.allow_zero_length_dataloader_with_multiple_devices is False + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + fast_dev_run=1, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + + trainer.validate(model, datamodule=dm) + + class DummyDS(torch.utils.data.Dataset): def __getitem__(self, index): return 1