
DistributedSamplerWrapper does not pass on the .set_epoch call to the underlying sampler #21454

@mojtababahrami

Bug description

By default, the Trainer calls `.set_epoch` on the dataloader's sampler at the start of each epoch in the fit loop.
However, when a sampler is wrapped by `DistributedSamplerWrapper`, the wrapper does not forward this call to the underlying sampler, so the wrapped sampler never receives the epoch update. I suggest the following fix:

from typing_extensions import override
from torch.utils.data.distributed import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    # current implementation

    @override
    def set_epoch(self, epoch: int) -> None:
        super().set_epoch(epoch)
        # forward the epoch to the wrapped sampler (only if it supports set_epoch)
        if hasattr(self.dataset._sampler, "set_epoch"):
            self.dataset._sampler.set_epoch(epoch)
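
If I read the implementation correctly, `self.dataset` here is the `_DatasetSamplerWrapper` that stores the original sampler as `_sampler`, so forwarding the call there keeps the wrapped sampler's epoch in sync with the wrapper's.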

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug
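
A minimal sketch that shows the behavior without a full Trainer run. The `EpochAwareSampler` below is a hypothetical sampler added only for illustration, the import path for `DistributedSamplerWrapper` is assumed (it lives in the fabric utilities in recent versions and may differ in yours), and `num_replicas=1, rank=0` are passed explicitly so no process group is required:

from torch.utils.data import Sampler

from lightning.fabric.utilities.distributed import DistributedSamplerWrapper


class EpochAwareSampler(Sampler):
    """Toy sampler that records the epoch it was told about."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self) -> int:
        return self.n


inner = EpochAwareSampler(8)
wrapper = DistributedSamplerWrapper(inner, num_replicas=1, rank=0)

wrapper.set_epoch(3)   # what the fit loop does at the start of each epoch
print(wrapper.epoch)   # 3 -> the wrapper's own epoch is updated
print(inner.epoch)     # 0 -> the call never reaches the wrapped sampler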

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

cc @ethanwharris @justusschock

Labels

bug (Something isn't working), data, distributed (Generic distributed-related topic), ver: 2.5.x
