Bug description
I want to use a custom distributed batch sampler at inference time.
The sampler looks like this:
```python
from typing import Iterator, Optional

from torch.utils.data import Dataset, DistributedSampler


class DistributedInferenceBatchSampler(DistributedSampler):
    def __init__(self, dataset: Dataset,
                 batch_size: int = 1,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = False,
                 seed: int = 0,
                 drop_last: bool = False,
                 ) -> None:
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=drop_last)
        # do stuff:
        #   sort data indices by datapoint length, batch up,
        #   subsample batches for current rank
        self.batches = ...  # nested list [[b1_1, b1_2, ...], [b2_1, b2_2, ...], ...]

    def __iter__(self) -> Iterator[list[int]]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)
```
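The elided batching logic isn't essential to the bug, but for concreteness, here is a minimal sketch of what the "do stuff" step could look like. The length computation and chunking are hypothetical stand-ins, assuming `__init__` ends with `self.batches = self._build_batches(batch_size)`:

```python
    def _build_batches(self, batch_size: int) -> list[list[int]]:
        # Hypothetical stand-in: sort indices by datapoint length so that
        # similarly-sized items end up in the same batch ...
        lengths = [len(self.dataset[i]) for i in range(len(self.dataset))]
        order = sorted(range(len(lengths)), key=lengths.__getitem__)
        # ... chunk the sorted indices into batches ...
        all_batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
        # ... and keep every num_replicas-th batch for the current rank.
        return all_batches[self.rank::self.num_replicas]
```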
I use this dataloader inside my data module:
```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class DataModule(LightningDataModule):
    # ...
    def predict_dataloader(self, rank=None, num_replicas=None):
        # define Dataset 'data'
        bsampler = DistributedInferenceBatchSampler(dataset=data,
                                                    batch_size=4,
                                                    num_replicas=num_replicas,
                                                    rank=rank)
        data_loader = DataLoader(data, batch_sampler=bsampler)
        return data_loader
```
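Note that Lightning calls `predict_dataloader()` without arguments, so `rank` and `num_replicas` stay `None` here; `DistributedSampler` then falls back to reading them from the process group, roughly:

```python
import torch.distributed as dist

# DistributedSampler's fallback when num_replicas/rank are None; this
# requires an initialized process group (which Lightning sets up for DDP).
num_replicas = dist.get_world_size()
rank = dist.get_rank()
```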
I'm running inference like so:
```python
trainer = Trainer(strategy='auto',
                  use_distributed_sampler=False,  # using custom distributed batch sampler
                  accelerator='gpu',
                  deterministic=False,
                  enable_progress_bar=True,
                  enable_model_summary=True,
                  devices=devices)
trainer.predict(model=self._model, datamodule=self._datamodule)
```
However, Lightning tries to replace my batch sampler despite the `use_distributed_sampler=False` flag, because it always does so in predict mode, and it fails because my sampler doesn't have the same signature as a PyTorch `BatchSampler`.
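For reference, the constructor signature that the re-instantiation expects is PyTorch's own, which takes a sampler rather than a dataset:

```python
import inspect

from torch.utils.data import BatchSampler

# PyTorch's BatchSampler is built from (sampler, batch_size, drop_last),
# whereas DistributedInferenceBatchSampler takes (dataset, batch_size, ...).
print(inspect.signature(BatchSampler.__init__))
```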
I've tried wrapping my custom `DistributedInferenceBatchSampler` like so:
```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import BatchSampler, DataLoader


class BatchSamplerWrapper(BatchSampler):
    def __init__(self, sampler, batch_size=1, drop_last=False):
        self.sampler = sampler
        self.batch_size = batch_size  # ignored
        self.drop_last = drop_last    # ignored

    def __iter__(self):
        for batch in self.sampler:
            yield batch

    def __len__(self):
        return len(self.sampler)


class DataModule(LightningDataModule):
    # ...
    def predict_dataloader(self, rank=None, num_replicas=None):
        # define Dataset 'data'
        bsampler = DistributedInferenceBatchSampler(dataset=data,
                                                    batch_size=4,
                                                    num_replicas=num_replicas,
                                                    rank=rank)
        wrapper = BatchSamplerWrapper(bsampler, batch_size=4, drop_last=False)
        data_loader = DataLoader(data, batch_sampler=wrapper)
        return data_loader
```
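Outside of Lightning, this wrapper does behave as intended. A quick sanity check with a toy dataset (assuming `self.batches` is populated as in the sketch above, and passing `num_replicas`/`rank` explicitly so no process group is needed):

```python
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    def __getitem__(self, idx):
        return list(range(idx % 5 + 1))  # variable-length datapoints

    def __len__(self):
        return 16


bsampler = DistributedInferenceBatchSampler(ToyDataset(), batch_size=4,
                                            num_replicas=2, rank=0)
wrapper = BatchSamplerWrapper(bsampler, batch_size=4, drop_last=False)
print(next(iter(wrapper)))  # a list of indices, e.g. [0, 5, 10, 15]
```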
However, Lightning replaces my `bsampler` inside the wrapper with a `torch.utils.data.sampler.SequentialSampler`, so `BatchSamplerWrapper.__iter__()` no longer has the intended behaviour: it yields an `int` rather than a list of `int`s, leading to:
```
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Error executing job with overrides: []
Traceback (most recent call last):
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/codebase/IgFlow/multiflow/experiments/inference.py", line 18, in sample
    run.sample()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/codebase/IgFlow/multiflow/experiments/model_run.py", line 125, in sample
    trainer.predict(model=self._model, datamodule=self._datamodule)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 864, in predict
    return call._call_and_handle_interrupt(
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 903, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    return self.predict_loop.run()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/prediction_loop.py", line 119, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 127, in __next__
    batch = super().__next__()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 56, in __next__
    batch = next(self.iterator)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 326, in __next__
    out = next(self._iterator)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 132, in __next__
    out = next(self.iterators[0])
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/data/nagagpu03/not-backed-up/nvme00/vavourakis/miniforge3/envs/mflow/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
TypeError: 'int' object is not iterable
```
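The failure mode is easy to reproduce in isolation: once Lightning has swapped the inner sampler for a `SequentialSampler`, the wrapper degenerates into a per-sample iterator:

```python
from torch.utils.data import SequentialSampler

# This is effectively what the wrapper looks like after Lightning rewrites it:
broken = BatchSamplerWrapper(SequentialSampler(range(8)), batch_size=4, drop_last=False)
print(next(iter(broken)))  # 0 -- a bare int, so `fetch()` can't iterate over it
```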
I just want to turn off this sampler-replacing behaviour. I have a similar setup during training (rather than inference) and that works fine (no wrappers required, either).
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
No response
Environment
No response
More info
No response