From 300c41e132409fde64bae89bb236199496f5e51a Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Wed, 26 Nov 2025 18:57:20 -0500 Subject: [PATCH] Fix LightningDataModule zero-length attribute --- src/lightning/pytorch/CHANGELOG.md | 7 ++++++ src/lightning/pytorch/core/datamodule.py | 3 +++ tests/tests_pytorch/core/test_datamodules.py | 25 ++++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b99e1c5969ccb..814dca5bf28b2 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). --- + +### Fixed + +- Ensured ``LightningDataModule`` always exposes ``allow_zero_length_dataloader_with_multiple_devices`` so Trainer zero-length checks don't raise ``AttributeError`` when subclasses skip ``super().__init__`` ([#21358](https://github.com/Lightning-AI/pytorch-lightning/issues/21358)) + +--- + ## [2.6.0] - 2025-11-21 ### Added diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 07ec02ef87bd8..66027d8c343d1 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -76,6 +76,9 @@ def teardown(self): """ name: Optional[str] = None + # Fallback for subclasses that don't call ``super().__init__``. The attribute normally gets initialized in + # ``DataHooks.__init__`` but Trainer loops expect it to exist regardless. + allow_zero_length_dataloader_with_multiple_devices: bool = False CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 6d97edb241fe5..47706ae66b735 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -21,6 +21,7 @@ import pytest import torch +from torch.utils.data import DataLoader from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -29,6 +30,7 @@ BoringDataModuleNoLen, BoringModel, IterableBoringDataModule, + RandomDataset, ) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn @@ -274,6 +276,29 @@ def train_dataloader(self): trainer.fit(model, dm) +def test_datamodule_allow_zero_length_attr_without_super(tmp_path): + class DataModuleWithoutSuper(LightningDataModule): + def __init__(self): + self.data = RandomDataset(32, 4) + + def val_dataloader(self): + return DataLoader(self.data, batch_size=2) + + dm = DataModuleWithoutSuper() + assert dm.allow_zero_length_dataloader_with_multiple_devices is False + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + fast_dev_run=1, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + + trainer.validate(model, datamodule=dm) + + class DummyDS(torch.utils.data.Dataset): def __getitem__(self, index): return 1