trainer.predict() raises an error when using a custom batch sampler #20741

@vb123er951

Bug description

As the title says, an error is raised when using a custom batch sampler in trainer.predict(), while trainer.fit() and trainer.test() work fine.

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Here is my custom batch sampler:

import random
from collections.abc import Iterator, Sequence

import torch
from torch.utils.data import Dataset


class MILSampler(torch.utils.data.BatchSampler):
    ''' this is a batch sampler '''
    weights: torch.Tensor

    def __init__(
        self,
        sampler,
        batch_size: int,
        drop_last: bool,
        data_source: Dataset,
        shuffle: bool = False,
        weights: Sequence[float] | None = None,
    ):
        super().__init__(None, 2, False)
        self.bag_images = data_source.bag_images
        self.batch_size = batch_size
        self.shuffle = shuffle

        # for exp: limit bag size
        # self.data_source = data_source
        self.limit_bag_size = 20

        self.length = 0
        for bag in self.bag_images.values():
            d, i = divmod(min(len(bag), self.limit_bag_size), self.batch_size)
            self.length += d + (i > 0)

        if weights is not None:
            weights = torch.as_tensor(weights, dtype=torch.double)
            if len(weights.shape) != 1:
                raise ValueError(
                    "weights should be a 1d sequence but given "
                    f"weights have shape {tuple(weights.shape)}"
                )
            if len(weights) != len(self.bag_images):
                raise ValueError(
                    "weights should have the same length as bag_images "
                    f"but given weights have length {len(weights)} "
                    f"and bag_images have length {len(self.bag_images)}"
                )
        self.weights = weights

    @property
    def num_samples(self) -> int:
        return sum(len(bag) for bag in self.bag_images.values())

    def __len__(self) -> int:
        return self.length

    def __iter__(self) -> Iterator[list[tuple]]:
        if self.weights is not None:
            bag_indexes = torch.multinomial(
                self.weights, self.num_samples, replacement=True
            )
        else:
            bag_indexes = (
                torch.randperm(len(self.bag_images))
                if self.shuffle
                else torch.arange(len(self.bag_images))
            )

        for bag_index in bag_indexes.tolist():
            bag = self.bag_images[bag_index]
            if len(bag) > self.limit_bag_size:
                bag = random.sample(bag, self.limit_bag_size)
            
            # mark the last sample of each (possibly sub-sampled) bag with a True flag
            data = [(img, i == len(bag) - 1) for i, img in enumerate(bag)]

            for idx in range(0, len(bag), self.batch_size):
                yield data[idx : idx + self.batch_size]
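
For context, here is roughly how the sampler is wired into the DataLoader and handed to the trainer. The dataset below is a simplified stand-in (my real dataset only needs to expose bag_images), and pipeline is the LightningModule referenced in the traceback below:

# Simplified reproduction sketch: the dataset is a stand-in for my real one,
# and `pipeline` is the LightningModule referenced in the traceback below.
import lightning as L
import torch
from torch.utils.data import DataLoader, Dataset


class BagDataset(Dataset):
    def __init__(self):
        # bag index -> list of sample indices belonging to that bag
        self.bag_images = {0: [0, 1, 2], 1: [3, 4]}

    def __len__(self):
        return 5

    def __getitem__(self, item):
        # the sampler yields lists of (sample_index, is_last_in_bag) pairs
        sample_idx, is_last = item
        return torch.randn(3), is_last


val_ds = BagDataset()
val_dl = DataLoader(
    val_ds,
    batch_sampler=MILSampler(None, batch_size=2, drop_last=False, data_source=val_ds),
)

trainer = L.Trainer(accelerator="cpu", devices=1, logger=False)
trainer.test(model=pipeline, dataloaders=val_dl)           # works
res = trainer.predict(model=pipeline, dataloaders=val_dl)  # raises the TypeError below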

Error messages and logs

File "main_mil_v1.py", line 122, in main
    res = trainer.predict(model=pipeline, dataloaders=val_dl)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 887, in predict
    return call._call_and_handle_interrupt(
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 928, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 1012, in _run
    results = self._run_stage()
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 1051, in _run_stage
    return self.predict_loop.run()
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\utilities.py", line 179, in _decorator
    return loop_run(self, *args, **kwargs)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\prediction_loop.py", line 105, in run
    self.setup_data()
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\loops\prediction_loop.py", line 158, in setup_data
    dl = _process_dataloader(trainer, trainer_fn, stage, dl)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py", line 485, in _process_dataloader
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=is_shuffled, mode=stage)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py", line 191, in _prepare_dataloader
    return _update_dataloader(dataloader, sampler, mode=mode)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 135, in _update_dataloader
    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 194, in _get_dataloader_init_args_and_kwargs
    dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode))
  File "d:\uv_envs\smt2\lib\site-packages\lightning\pytorch\utilities\data.py", line 284, in _dataloader_init_kwargs_resolve_sampler
    batch_sampler = batch_sampler_cls(
TypeError: MILSampler.__init__() missing 1 required positional argument: 'data_source'
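
If I read the traceback correctly, the predict loop rebuilds the DataLoader and re-instantiates the batch sampler with only sampler, batch_size and drop_last, so the required data_source argument is never supplied; fit() and test() apparently do not go through this re-instantiation, which would explain why only predict() fails. One possible workaround (just a sketch, under the assumption that Lightning can pick up the sampler's constructor arguments when the DataLoader is built inside a *_dataloader hook; names other than MILSampler are placeholders) is to construct the DataLoader in a LightningDataModule instead of passing dataloaders= directly:

# Workaround sketch, not a confirmed fix: build the DataLoader inside
# predict_dataloader() so Lightning may re-create the sampler with data_source set.
import lightning as L
from torch.utils.data import DataLoader


class MILDataModule(L.LightningDataModule):
    def __init__(self, val_ds, batch_size: int = 2):
        super().__init__()
        self.val_ds = val_ds
        self.batch_size = batch_size

    def predict_dataloader(self):
        batch_sampler = MILSampler(
            None,
            batch_size=self.batch_size,
            drop_last=False,
            data_source=self.val_ds,
        )
        return DataLoader(self.val_ds, batch_sampler=batch_sampler)


# res = trainer.predict(model=pipeline, datamodule=MILDataModule(val_ds))

Alternatively, changing MILSampler so it can be re-created from just (sampler, batch_size, drop_last), e.g. by making data_source optional, might also sidestep the error, though it is not obvious how the sampler would then get bag_images.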

Environment

Current environment
PyTorch Lightning Version: 2.5.1
PyTorch Version: 2.5.1+cu121
Python version: 3.10

More info

No response

Labels

bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
