Skip to content

Commit 300c41e

Browse files
committed
Fix LightningDataModule zero-length attribute
1 parent b09e96e commit 300c41e

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77
---
88

9+
10+
### Fixed
11+
12+
- Ensured ``LightningDataModule`` always exposes ``allow_zero_length_dataloader_with_multiple_devices`` so Trainer zero-length checks don't raise ``AttributeError`` when subclasses skip ``super().__init__`` ([#21358](https://github.com/Lightning-AI/pytorch-lightning/issues/21358))
13+
14+
---
15+
916
## [2.6.0] - 2025-11-21
1017

1118
### Added

src/lightning/pytorch/core/datamodule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def teardown(self):
7676
"""
7777

7878
name: Optional[str] = None
79+
# Fallback for subclasses that don't call ``super().__init__``. The attribute normally gets initialized in
80+
# ``DataHooks.__init__`` but Trainer loops expect it to exist regardless.
81+
allow_zero_length_dataloader_with_multiple_devices: bool = False
7982
CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
8083
CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
8184
CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import pytest
2323
import torch
24+
from torch.utils.data import DataLoader
2425

2526
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
2627
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -29,6 +30,7 @@
2930
BoringDataModuleNoLen,
3031
BoringModel,
3132
IterableBoringDataModule,
33+
RandomDataset,
3234
)
3335
from lightning.pytorch.profilers.simple import SimpleProfiler
3436
from lightning.pytorch.trainer.states import TrainerFn
@@ -274,6 +276,29 @@ def train_dataloader(self):
274276
trainer.fit(model, dm)
275277

276278

279+
def test_datamodule_allow_zero_length_attr_without_super(tmp_path):
280+
class DataModuleWithoutSuper(LightningDataModule):
281+
def __init__(self):
282+
self.data = RandomDataset(32, 4)
283+
284+
def val_dataloader(self):
285+
return DataLoader(self.data, batch_size=2)
286+
287+
dm = DataModuleWithoutSuper()
288+
assert dm.allow_zero_length_dataloader_with_multiple_devices is False
289+
290+
model = BoringModel()
291+
trainer = Trainer(
292+
default_root_dir=tmp_path,
293+
fast_dev_run=1,
294+
enable_model_summary=False,
295+
enable_checkpointing=False,
296+
logger=False,
297+
)
298+
299+
trainer.validate(model, datamodule=dm)
300+
301+
277302
class DummyDS(torch.utils.data.Dataset):
278303
def __getitem__(self, index):
279304
return 1

0 commit comments

Comments
 (0)