Description
🐛 Bug
I think this is probably related to issue #19624 of PyTorch Lightning: when an IterableDataset defines __len__() (as StreamingDataset does), the dataset size is not evenly divisible by the batch size, and each of multiple workers drops its incomplete final batch, the data loader raises StopIteration and the validation epoch is skipped.
This problem was already reported and fixed in issue #133. However, the problem still persists if the user queries the dataset length first. @enrico-stauss and @tchaton solved that issue; do you have an idea whether a corner case might have been overlooked?
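To make the batch arithmetic concrete, here is a rough sketch in plain Python using the numbers from the reproduction (300 training samples, chunk_size=25, 4 workers, batch size 8). It assumes the 12 chunks are split evenly across workers (3 chunks, i.e. 75 samples, per worker) and that len(dataset) reports the global 300 samples, in which case PyTorch's DataLoader reports 300 // 8 = 37 batches with drop_last=True. The helper names are hypothetical and are not LitData or Lightning APIs.

```python
# Hypothetical helpers for illustration only; not part of LitData or Lightning.

def batches_from_global_len(total_samples, batch_size):
    # Batch count suggested by len(dataset) // batch_size (drop_last=True).
    return total_samples // batch_size


def batches_with_per_worker_drop(samples_per_worker, batch_size):
    # Each worker batches its own shard and drops its incomplete last batch.
    return sum(n // batch_size for n in samples_per_worker)


NUM_SAMPLES = 300
CHUNK_SIZE = 25
NUM_WORKERS = 4
BATCH_SIZE = 8

# Assumption: chunks are distributed evenly, 3 chunks (75 samples) per worker.
num_chunks = NUM_SAMPLES // CHUNK_SIZE
chunks_per_worker = num_chunks // NUM_WORKERS
per_worker = [chunks_per_worker * CHUNK_SIZE] * NUM_WORKERS

expected = batches_from_global_len(NUM_SAMPLES, BATCH_SIZE)
actual = batches_with_per_worker_drop(per_worker, BATCH_SIZE)
print(expected, actual)  # 37 36
```

Under these assumptions the loader yields 36 batches while 37 are expected, which would plausibly explain why the training loop runs out of batches before reaching its end-of-epoch validation check.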
To Reproduce
See the updated code sample for how to reproduce the issue.
Code sample
from pathlib import Path

import torch
from litdata import StreamingDataLoader, StreamingDataset, optimize
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

DATA_ROOT = Path("data")
DROP_LAST = True
BATCH_SIZE = 8
NUM_WORKERS = 4
QUERY_DATASET_LENGTH = True


def generate_sample(i):
    return {"index": i, "data": torch.rand((3, 20, 20))}


class MyDataModule(LightningDataModule):
    def prepare_data(self):
        # Optimize the datasets only once; later runs reuse the existing chunks.
        train_dir = DATA_ROOT / "train"
        if not train_dir.is_dir():
            train_dir.mkdir(parents=True)
            optimize(fn=generate_sample, inputs=list(range(300)), output_dir=train_dir, chunk_size=25, num_workers=4)
        val_dir = DATA_ROOT / "val"
        if not val_dir.is_dir():
            val_dir.mkdir(parents=True)
            optimize(fn=generate_sample, inputs=list(range(100)), output_dir=val_dir, chunk_size=25, num_workers=4)

    def train_dataloader(self):
        dataset = StreamingDataset(DATA_ROOT / "train")
        if QUERY_DATASET_LENGTH:
            # Merely querying the length here is enough to trigger the bug.
            len(dataset)
        return StreamingDataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

    def val_dataloader(self):
        dataset = StreamingDataset(DATA_ROOT / "val")
        return StreamingDataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
        self.did_validation = False

    def training_step(self, batch):
        loss = self.model(batch["data"]).mean()
        return loss

    def on_train_epoch_start(self):
        print("TRAINING START")

    def on_train_epoch_end(self):
        print("TRAINING END")

    def validation_step(self, batch):
        self.did_validation = True
        loss = self.model(batch["data"]).mean()
        return loss

    def on_validation_epoch_start(self):
        print("VALIDATION START")

    def on_validation_epoch_end(self):
        print("VALIDATION END")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


def main():
    model = MyModel()
    trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
    datamodule = MyDataModule()
    trainer.fit(model, datamodule=datamodule)
    print("Performed validation:", model.did_validation)


if __name__ == "__main__":
    main()
Expected behavior
Validation is expected to run. However, when the batch size is chosen so that it does not evenly divide the dataset, incomplete batches are dropped, and len(dataset) is called, validation_step() is never called.
Additional context
We can argue about whether this should be fixed in PyTorch, PyTorch Lightning, or LitData, but one would expect these libraries to work well together, or at least to detect the problematic edge cases and display an error.
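Until this is handled by one of the libraries, a user-side check could at least surface the mismatch. The sketch below is a hypothetical diagnostic, not an existing API: it iterates a loader once and warns if the number of batches actually produced differs from len(loader). FakeLoader is an illustrative stand-in for a loader whose len() overstates its batch count.

```python
import warnings


def check_reported_length(loader):
    """Hypothetical diagnostic: compare len(loader) with the batches produced."""
    reported = len(loader)
    produced = sum(1 for _ in loader)  # consumes one full pass over the loader
    if produced != reported:
        warnings.warn(
            f"len(loader) reports {reported} batches, but {produced} were produced"
        )
    return reported, produced


class FakeLoader:
    """Illustrative stand-in whose __len__ disagrees with its iterator."""

    def __init__(self, reported, produced):
        self.reported = reported
        self.produced = produced

    def __len__(self):
        return self.reported

    def __iter__(self):
        return iter(range(self.produced))


reported, produced = check_reported_length(FakeLoader(37, 36))
print(reported, produced)  # 37 36
```

Because it consumes a full pass over the data, such a check is only suited to one-off debugging, not to running every epoch.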
Environment details
- PyTorch Version (e.g., 1.0): 2.9.1
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.13.7
- GPU models and configuration: occurs in both CPU-only and GPU configurations