
overfit_batches replaces custom sampler with SequentialSampler #21282

@nilsleh

Bug description

More complicated datasets for real-world applications can demand more sophisticated data sampling: samples are drawn not by plain integer indices but by richer query objects, for example a dictionary with additional information used to query the dataset. The overfit_batches setting should remain useful in these settings to test an implementation, but the dataloaders fail, because with overfit_batches enabled Lightning replaces any custom sampler passed to the DataLoader with PyTorch's standard SequentialSampler.
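
To make the failure concrete, here is a minimal sketch (a hypothetical DictQueryDataset in plain PyTorch, not Lightning internals) of what happens after the swap: SequentialSampler yields plain ints, so a dataset whose __getitem__ expects a dict query receives the wrong type.

from torch.utils.data import DataLoader, Dataset, SequentialSampler

class DictQueryDataset(Dataset):
    """Hypothetical dataset that only accepts dict queries."""

    def __len__(self):
        return 4

    def __getitem__(self, query):
        if not isinstance(query, dict):
            raise TypeError(f"Expected dict query, got {type(query)}")
        return query["index"]

ds = DictQueryDataset()
# This is effectively what the DataLoader sees once the custom sampler is gone:
loader = DataLoader(ds, sampler=SequentialSampler(ds), batch_size=2)
next(iter(loader))  # TypeError: Expected dict query, got <class 'int'>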

torchgeo is an example of a domain library that makes extensive use of custom query logic for geospatial data (its samplers yield bounding-box queries rather than integer indices) and is heavily integrated with Lightning.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

"""Minimal reproducible example: Custom dict sampler with Lightning overfit_batches."""

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from lightning import LightningModule, LightningDataModule, Trainer
from collections.abc import Iterator
import numpy as np


# ============================================================================
# 1. SIMPLE DATASET (accepts dict queries)
# ============================================================================

class SimpleQueryDataset(Dataset):
    """Simple dataset that accepts dictionary queries."""
    
    def __init__(self, n_items: int = 100):
        """Initialize with some dummy items."""
        self.n_items = n_items
    
    def __len__(self) -> int:
        return self.n_items
    
    def __getitem__(self, query: dict) -> dict:
        """Accept a dictionary query.
        
        Args:
            query: Dict with keys:
                - 'index': int identifier
                - 'param': some parameter value
        
        Returns:
            Dict with random dummy data
        """
        # Validate input type
        if not isinstance(query, dict):
            raise TypeError(
                f"Expected dict query, got {type(query)}. "
                f"Content: {query}"
            )
        
        index = query["index"]
        param = query["param"]
        
        # Vary the number of points deterministically with the query index
        n_points = 10 + (index % 20)
        
        return {
            "data": torch.randn(n_points, 3),
            "target": torch.randn(n_points, 1),
            "query_index": index,
            "query_param": param,
        }


# ============================================================================
# 2. SIMPLE QUERY SAMPLER (yields dict queries)
# ============================================================================

class SimpleQuerySampler(Sampler):
    """Simple sampler that yields dictionary queries."""
    
    def __init__(self, n_queries: int = 50, shuffle: bool = True, seed: int = 42):
        """Generate some simple queries."""
        self.n_queries = n_queries
        self.shuffle = shuffle
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        
        # Pre-generate queries
        self.queries = []
        for i in range(n_queries):
            self.queries.append({
                "index": i,
                "param": self.rng.uniform(0, 1),
            })
    
    def __iter__(self) -> Iterator[dict]:
        """Yield query dictionaries."""
        indices = list(range(self.n_queries))
        
        if self.shuffle:
            self.rng.shuffle(indices)
        
        for idx in indices:
            yield self.queries[idx]
    
    def __len__(self) -> int:
        return self.n_queries


# ============================================================================
# 3. LIGHTNING DATAMODULE
# ============================================================================

class SimpleDataModule(LightningDataModule):
    """Lightning DataModule using dict-based query sampler."""
    
    def __init__(self, batch_size: int = 4, num_workers: int = 0):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        # Create dataset
        self.dataset = SimpleQueryDataset(n_items=100)
    
    def train_dataloader(self) -> DataLoader:
        sampler = SimpleQuerySampler(n_queries=50, shuffle=True, seed=42)
        
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )
    
    def val_dataloader(self) -> DataLoader:
        sampler = SimpleQuerySampler(n_queries=20, shuffle=False, seed=43)
        
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )
    
    def _collate_fn(self, batch: list[dict]) -> dict:
        """Simple collate function."""
        # Stack data across batch
        data = torch.cat([sample["data"] for sample in batch], dim=0)
        target = torch.cat([sample["target"] for sample in batch], dim=0)
        
        return {
            "data": data,
            "target": target,
            "batch_sizes": [sample["data"].shape[0] for sample in batch],
            "query_indices": [sample["query_index"] for sample in batch],
        }


# ============================================================================
# 4. SIMPLE MODEL
# ============================================================================

class SimpleModel(LightningModule):
    """Simple model for testing."""
    
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)
    
    def forward(self, x):
        return self.linear(x)
    
    def training_step(self, batch, batch_idx):
        data = batch["data"]
        target = batch["target"]
        pred = self(data)
        loss = torch.nn.functional.mse_loss(pred, target)
        
        self.log("train_loss", loss)
        print(f"  Train batch {batch_idx}: query_indices={batch['query_indices']}")
        return loss
    
    def validation_step(self, batch, batch_idx):
        data = batch["data"]
        target = batch["target"]
        pred = self(data)
        loss = torch.nn.functional.mse_loss(pred, target)
        
        self.log("val_loss", loss)
        print(f"  Val batch {batch_idx}: query_indices={batch['query_indices']}")
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# ============================================================================
# TEST FUNCTIONS
# ============================================================================

def test_lightning_overfit():
    """Test: Lightning WITH overfit_batches (THE FAILING CASE)."""
    datamodule = SimpleDataModule(batch_size=2, num_workers=0)
    model = SimpleModel()
    
    trainer = Trainer(
        max_epochs=1,
        accelerator="cpu",
        enable_checkpointing=False,
        logger=False,
        enable_model_summary=False,
        overfit_batches=2,  # This causes the issue
    )
    
    trainer.fit(model, datamodule)

if __name__ == "__main__":
    test_lightning_overfit()
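
As a possible interim workaround (a sketch, not an official Lightning recipe), one can approximate overfit_batches with limit_train_batches/limit_val_batches, which cap how many batches are consumed without touching the DataLoader's sampler; shuffling must then be disabled in the custom sampler so every epoch revisits the same leading batches.

def test_limit_batches_workaround():
    """Sketch: emulate overfitting while keeping the dict sampler.

    Assumes limit_train_batches/limit_val_batches only truncate iteration
    and do not replace the sampler. For multi-epoch runs, construct the
    train sampler with shuffle=False so each epoch sees the same batches.
    """
    datamodule = SimpleDataModule(batch_size=2, num_workers=0)
    model = SimpleModel()

    trainer = Trainer(
        max_epochs=1,
        accelerator="cpu",
        enable_checkpointing=False,
        logger=False,
        enable_model_summary=False,
        limit_train_batches=2,  # cap batches instead of overfit_batches
        limit_val_batches=2,
    )

    trainer.fit(model, datamodule)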

Error messages and logs

With the repro above, training fails on the first batch once the custom sampler has been replaced: the DataLoader starts feeding integer indices to SimpleQueryDataset, whose own type check raises, e.g.

TypeError: Expected dict query, got <class 'int'>. Content: 0

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @tchaton
