-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Closed
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x
Description
Bug description
As the title says, an error is raised when a custom batch sampler is used with trainer.predict(), whereas trainer.fit() and trainer.test() work fine.
What version are you seeing the problem on?
v2.5
How to reproduce the bug
Here is my custom batch sampler:
class MILSampler(torch.utils.data.BatchSampler):
''' this is a batch sampler '''
weights: torch.Tensor
def __init__(
self,
sampler,
batch_size: int,
drop_last: bool,
data_source: Dataset,
shuffle: bool = False,
weights: Sequence[float] | None = None,
):
super().__init__(None, 2, False)
self.bag_images = data_source.bag_images
self.batch_size = batch_size
self.shuffle = shuffle
# for exp: limit bag size
# self.data_source = data_source
self.limit_bag_size = 20
self.length = 0
for bag in self.bag_images.values():
d, i = divmod(min(len(bag), self.limit_bag_size), self.batch_size)
self.length += d + (i > 0)
if weights is not None:
weights = torch.as_tensor(weights, dtype=torch.double)
if len(weights.shape) != 1:
raise ValueError(
"weights should be a 1d sequence but given "
f"weights have shape {tuple(weights.shape)}"
)
if len(weights) != len(self.bag_images):
raise ValueError(
"weights should have the same length as bag_images "
f"but given weights have length {len(weights)} "
f"and bag_images have length {len(self.bag_images)}"
)
self.weights = weights
@property
def num_samples(self) -> int:
return sum(len(bag) for bag in self.bag_images.values())
def __len__(self) -> int:
return self.length
def __iter__(self) -> Iterator[list[int]]:
if self.weights is not None:
bag_indexes = torch.multinomial(
self.weights, self.num_samples, replacement=True
)
else:
bag_indexes = (
torch.randperm(len(self.bag_images))
if self.shuffle
else torch.arange(len(self.bag_images))
)
for bag_index in bag_indexes.tolist():
bag = self.bag_images[bag_index]
if len(bag) > self.limit_bag_size:
bag = random.sample(bag, self.limit_bag_size)
data = [None for _ in range(len(bag))]
for i in range(len(bag)):
data[i] = (bag[i], False)
if i == len(bag) - 1:
data[i] = (bag[i], True)
for idx in range(0, len(bag), self.batch_size):
yield data[idx : idx + self.batch_size]
Error messages and logs
File "main_mil_v1.py", line 122, in main
res = trainer.predict(model=pipeline, dataloaders=val_dl)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 887, in predict
return call._call_and_handle_interrupt(
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 928, in _predict_impl
results = self._run(model, ckpt_path=ckpt_path)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 1012, in _run
results = self._run_stage()
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 1051, in _run_stage
return self.predict_loop.run()
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\utilities.py", line 179, in _decorator
return loop_run(self, *args, **kwargs)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\prediction_loop.py", line 105, in run
self.setup_data()
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\prediction_loop.py", line 158, in setup_data
dl = _process_dataloader(trainer, trainer_fn, stage, dl)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py", line 485, in _process_dataloader
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py", line 191, in _prepare_dataloader
return _update_dataloader(dataloader, sampler, mode=mode)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 135, in _update_dataloader
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 194, in _get_dataloader_init_args_and_kwargs
dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode))
File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 284, in _dataloader_init_kwargs_resolve_sampler
batch_sampler = batch_sampler_cls(
TypeError: MILSampler.__init__() missing 1 required positional argument: 'data_source'
Environment
Current environment
PyTorch Lightning Version (e.g., 2.5.0): 2.5.1
PyTorch Version (e.g., 2.5): 2.5.1+cu121
Python version (e.g., 3.12): 3.10
More info
No response
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't workingneeds triageWaiting to be triaged by maintainersWaiting to be triaged by maintainersver: 2.5.x