
Cannot turn off sampler injection at inference time. #20253

@ovavourakis

Description

Bug description

I want to use a custom distributed batch sampler at inference time.
The sampler looks like this:

from typing import Iterator, Optional

from torch.utils.data import Dataset, DistributedSampler
from torch.utils.data.dataset import T_co


class DistributedInferenceBatchSampler(DistributedSampler):

    def __init__(self, dataset: Dataset,
                       batch_size: int = 1,
                       num_replicas: Optional[int] = None,
                       rank: Optional[int] = None,
                       shuffle: bool = False,
                       seed: int = 0,
                       drop_last: bool = False) -> None:
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=drop_last)

        # do stuff:
        # sort data indices by datapoint length, batch up,
        # subsample batches for the current rank
        self.batches = ...  # nested list [[b1_1, b1_2, ...], [b2_1, b2_2, ...], ...]

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)

I use this dataloader inside my data module:

class DataModule(LightningDataModule):

    # ...

    def predict_dataloader(self, rank=None, num_replicas=None):

        # define Dataset 'data'
        
        bsampler = DistributedInferenceBatchSampler(dataset=data,
                                                    batch_size=4,
                                                    num_replicas=num_replicas,
                                                    rank=rank)

        data_loader = DataLoader(data,
                                 batch_sampler=bsampler)

        return data_loader

I'm running inference like so:

trainer = Trainer(strategy='auto',
                  use_distributed_sampler=False,  # using custom distributed batch sampler
                  accelerator='gpu',
                  deterministic=False,
                  enable_progress_bar=True,
                  enable_model_summary=True,
                  devices=devices)
trainer.predict(model=self._model, datamodule=self._datamodule)

However, Lightning tries to replace my batch sampler despite the use_distributed_sampler=False flag, because it always does so in predict mode, and it fails because my sampler does not have the same signature as a PyTorch BatchSampler.
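
For reference, this is roughly the re-instantiation I believe Lightning attempts (pieced together from the error message, not quoted from the Lightning source), reusing the data and DistributedInferenceBatchSampler from above:

# Rough reconstruction (my assumption, not Lightning's actual code) of the
# predict-mode rewiring: Lightning swaps in a plain SequentialSampler and
# rebuilds the batch sampler with the stock BatchSampler signature
# (sampler, batch_size, drop_last).
from torch.utils.data import SequentialSampler

injected_sampler = SequentialSampler(data)
rebuilt = DistributedInferenceBatchSampler(injected_sampler,   # expects a Dataset, not a Sampler
                                           batch_size=4,
                                           drop_last=False)    # num_replicas / rank are lost here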

I've tried wrapping my custom DistributedInferenceBatchSampler like so:

class BatchSamplerWrapper(BatchSampler):
    def __init__(self, sampler, batch_size=1, drop_last=False):
        self.sampler = sampler
        self.batch_size = batch_size # ignored
        self.drop_last = drop_last   # ignored

    def __iter__(self):
        for batch in self.sampler:
            yield batch

    def __len__(self):
        return len(self.sampler)

class DataModule(LightningDataModule):

    # ...

    def predict_dataloader(self, rank=None, num_replicas=None):

        # define Dataset 'data'
        
        bsampler = DistributedInferenceBatchSampler(dataset=data,
                                                    batch_size=4,
                                                    num_replicas=num_replicas,
                                                    rank=rank)

        wrapper = BatchSamplerWrapper(bsampler, batch_size=4, drop_last=False)

        data_loader = DataLoader(data,
                                 batch_sampler=wrapper)

        return data_loader

However, Lightning replaces my bsampler inside the wrapper with a torch.utils.data.sampler.SequentialSampler, so BatchSamplerWrapper.__iter__() no longer behaves as intended: it yields an int rather than a list of ints, leading to:
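
To make the failure concrete, here is a minimal standalone reproduction of why the wrapper yields ints once its inner sampler has been swapped out (using a dummy list in place of my dataset and the BatchSamplerWrapper defined above):

from torch.utils.data import SequentialSampler

dummy_data = list(range(8))
seq = SequentialSampler(dummy_data)

# After Lightning's rewiring, wrapper.sampler is a SequentialSampler, which
# yields single indices, so the wrapper yields ints instead of lists of ints.
wrapper = BatchSamplerWrapper(seq, batch_size=4, drop_last=False)
print(next(iter(wrapper)))  # 0  -- an int, not [0, 1, 2, 3]

# The DataLoader worker then evaluates `[self.dataset[idx] for idx in 0]`,
# which is the TypeError: 'int' object is not iterable shown below.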

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Error executing job with overrides: []
Traceback (most recent call last):
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/codebase/IgFlow/multiflow/experiments/inference.py", line 18, in sample
    run.sample()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/codebase/IgFlow/multiflow/experiments/model_run.py", line 125, in sample
    trainer.predict(model=self._model, datamodule=self._datamodule)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 864, in predict
    return call._call_and_handle_interrupt(
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 903, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    return self.predict_loop.run()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/prediction_loop.py", line 119, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 127, in __next__
    batch = super().__next__()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 56, in __next__
    batch = next(self.iterator)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 326, in __next__
    out = next(self._iterator)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 132, in __next__
    out = next(self.iterators[0])
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
TypeError: 'int' object is not iterable

I just want to turn off this sampler-replacing behaviour. I have a similar setup during training (rather than inference) and that works fine (no wrappers required, either).
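
In the meantime, the only workaround I can think of is to make the batch sampler itself tolerate the rewiring: give it the stock (sampler, batch_size, drop_last) constructor and rebuild the length-sorted, rank-subsampled batches from whatever sampler Lightning hands it. Below is a minimal sketch of that idea, assuming Lightning really does rebuild batch samplers with exactly that signature (I have not verified this against the Lightning source); the _length_of helper is a hypothetical stand-in for my real per-datapoint length lookup:

import torch.distributed as dist
from torch.utils.data import BatchSampler, Sampler


class ReinstantiationFriendlyBatchSampler(BatchSampler):
    """Batch sampler with the stock (sampler, batch_size, drop_last) signature,
    so that Lightning's predict-mode rewiring can rebuild it without erroring.
    The batches are recomputed from the indices the given sampler yields."""

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool) -> None:
        super().__init__(sampler, batch_size, drop_last)

        # Sort indices by datapoint length, then batch them up.
        indices = sorted(sampler, key=self._length_of)
        batches = [indices[i:i + batch_size]
                   for i in range(0, len(indices), batch_size)]

        # Subsample batches for the current rank (falls back to a single
        # process when torch.distributed is not initialised).
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        self.batches = batches[rank::world_size]

    @staticmethod
    def _length_of(idx: int) -> int:
        # Hypothetical placeholder: return the length of datapoint `idx` here.
        return idx

    def __iter__(self):
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)

This keeps the batching logic intact, but it is clearly a workaround; an option to disable the injection outright would still be the clean solution.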

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

No response

More info

No response

Labels

bug (Something isn't working), ver: 2.1.x, waiting on author (Waiting on user action, correction, or update)
