PyTorch Lightning doesn't perform validation with StreamingDataset in a special case #772

@senarvi

🐛 Bug

I think this is probably related to PyTorch Lightning issue #19624. When an IterableDataset defines __len__() (as StreamingDataset does), the dataset size doesn't divide evenly into batches, and multiple workers each drop their incomplete batch, the data loader raises StopIteration and the validation epoch is skipped.
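
The mismatch can be illustrated without any of the libraries involved. The sketch below assumes that samples are split as evenly as possible across workers and that each worker forms batches from its own share; how StreamingDataset actually shards samples is not stated in this issue, and both helper functions are hypothetical. With the training settings from the code sample (300 samples, batch size 8, 4 workers, incomplete batches dropped), the workers together yield fewer batches than a count derived from the dataset length alone.

```python
def batches_from_length(num_samples, batch_size, drop_last):
    """Batch count derived from the dataset length alone."""
    full, remainder = divmod(num_samples, batch_size)
    return full if drop_last or not remainder else full + 1


def batches_from_workers(num_samples, batch_size, num_workers, drop_last):
    """Batch count when every worker batches its own share of the samples."""
    total = 0
    for worker in range(num_workers):
        # Assumed sharding: spread the samples as evenly as possible.
        extra = 1 if worker < num_samples % num_workers else 0
        share = num_samples // num_workers + extra
        full, remainder = divmod(share, batch_size)
        total += full
        if remainder and not drop_last:
            total += 1  # keep the worker's incomplete batch
    return total


expected = batches_from_length(300, 8, drop_last=True)
actual = batches_from_workers(300, 8, num_workers=4, drop_last=True)
print(expected, actual)  # prints "37 36"
```

Each worker receives 75 samples and drops 3 of them, so one batch fewer arrives than the length-based count promises.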

This problem was already reported and fixed in issue #133. However, it still persists if the user queries the dataset length first. @enrico-stauss and @tchaton, you solved that issue. Do you have any idea whether a corner case was overlooked?

To Reproduce

See the updated code sample for how to reproduce the issue.

Code sample
from pathlib import Path

import torch
from litdata import StreamingDataLoader, StreamingDataset, optimize
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

DATA_ROOT = Path("data")
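DROP_LAST = True  # drop incomplete training batches
BATCH_SIZE = 8  # does not evenly divide the 300 training samples
NUM_WORKERS = 4
QUERY_DATASET_LENGTH = True  # call len(dataset) before creating the training data loader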


def generate_sample(i):
    return {"index": i, "data": torch.rand((3, 20, 20))}


class MyDataModule(LightningDataModule):
    def prepare_data(self):
        train_dir = DATA_ROOT / "train"
        if not train_dir.is_dir():
            train_dir.mkdir(parents=True)
            optimize(fn=generate_sample, inputs=list(range(300)), output_dir=train_dir, chunk_size=25, num_workers=4)

        val_dir = DATA_ROOT / "val"
        if not val_dir.is_dir():
            val_dir.mkdir(parents=True)
            optimize(fn=generate_sample, inputs=list(range(100)), output_dir=val_dir, chunk_size=25, num_workers=4)

    def train_dataloader(self):
        dataset = StreamingDataset(DATA_ROOT / "train")
        if QUERY_DATASET_LENGTH:
            len(dataset)
        return StreamingDataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

    def val_dataloader(self):
        dataset = StreamingDataset(DATA_ROOT / "val")
        return StreamingDataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
        self.did_validation = False

    def training_step(self, batch):
        loss = self.model(batch["data"]).mean()
        return loss

    def on_train_epoch_start(self):
        print("TRAINING START")

    def on_train_epoch_end(self):
        print("TRAINING END")

    def validation_step(self, batch):
        self.did_validation = True
        loss = self.model(batch["data"]).mean()
        return loss

    def on_validation_epoch_start(self):
        print("VALIDATION START")

    def on_validation_epoch_end(self):
        print("VALIDATION END")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


def main():
    model = MyModel()
    trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
    datamodule = MyDataModule()
    trainer.fit(model, datamodule=datamodule)
    print("Performed validation:", model.did_validation)


if __name__ == "__main__":
    main()

Expected behavior

Validation should be performed. However, when the batch size doesn't evenly divide the dataset, incomplete batches are dropped, and len(dataset) is called, validation_step() is never called.

Additional context

We can argue about whether this should be fixed in PyTorch, PyTorch Lightning, or LitData, but one would expect that those libraries work well together, or at least detect the problematic edge cases and display an error.
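
As a rough illustration of what detecting the edge case could look like, here is a library-agnostic sketch. `checked_batches` is a hypothetical helper, not part of PyTorch, PyTorch Lightning, or LitData: it wraps any iterable of batches and raises an explicit error when fewer batches arrive than the declared length promised, instead of ending silently.

```python
def checked_batches(batches, declared_len):
    """Yield batches unchanged, then fail loudly if too few were produced."""
    count = 0
    for batch in batches:
        count += 1
        yield batch
    if count < declared_len:
        raise RuntimeError(
            f"data loader declared {declared_len} batches but produced only {count}"
        )


# Usage: a loader that promises 37 batches but delivers 36.
try:
    for _ in checked_batches(range(36), declared_len=37):
        pass
except RuntimeError as error:
    message = str(error)
    print(message)
```

Whether a check like this belongs in the data loader or in the training loop is exactly the open question above; the sketch only shows that the mismatch is cheap to detect once iteration finishes.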

Environment detail
  • PyTorch version: 2.9.1
  • OS: Linux
  • How PyTorch was installed: pip
  • Build command (if compiled from source):
  • Python version: 3.13.7
  • GPU models and configuration: occurs in both CPU-only and GPU configurations

Labels: bug, help wanted