| 
19 | 19 | import pytest  | 
20 | 20 | from lightning_utilities.test.warning import no_warning_call  | 
21 | 21 | from torch import Tensor  | 
22 |  | -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler  | 
 | 22 | +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler  | 
23 | 23 | 
 
  | 
24 | 24 | import lightning.fabric  | 
25 | 25 | from lightning.fabric.utilities.distributed import DistributedSamplerWrapper  | 
 | 
30 | 30 |     _check_dataloader_iterable,  | 
31 | 31 |     _DataHookSelector,  | 
32 | 32 |     _DataLoaderSource,  | 
 | 33 | +    _is_simple_sampler_replaceable,  | 
 | 34 | +    _resolve_overfit_batches,  | 
33 | 35 |     _worker_check,  | 
34 | 36 |     warning_cache,  | 
35 | 37 | )  | 
@@ -696,3 +698,49 @@ def test_iterable_check_on_known_iterators():  | 
696 | 698 |     dataloader.__iter__ = Mock()  | 
697 | 699 |     _check_dataloader_iterable(dataloader, Mock(), Mock())  | 
698 | 700 |     dataloader.__iter__.assert_not_called()  | 
 | 701 | + | 
 | 702 | + | 
 | 703 | +def test_is_simple_sampler_replaceable():  | 
 | 704 | +    """Test that _is_simple_sampler_replaceable correctly identifies simple vs custom samplers."""  | 
 | 705 | +    dataset = RandomDataset(32, 64)  | 
 | 706 | + | 
 | 707 | +    assert _is_simple_sampler_replaceable(SequentialSampler(dataset)) is True  | 
 | 708 | +    assert _is_simple_sampler_replaceable(RandomSampler(dataset)) is True  | 
 | 709 | + | 
 | 710 | +    class CustomSampler(Sampler):  | 
 | 711 | +        def __init__(self, dataset):  | 
 | 712 | +            self.dataset = dataset  | 
 | 713 | + | 
 | 714 | +        def __iter__(self):  | 
 | 715 | +            return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])  | 
 | 716 | + | 
 | 717 | +        def __len__(self):  | 
 | 718 | +            return len(self.dataset)  | 
 | 719 | + | 
 | 720 | +    assert _is_simple_sampler_replaceable(CustomSampler(dataset)) is False  | 
 | 721 | + | 
 | 722 | + | 
 | 723 | +def test_resolve_overfit_batches_preserves_custom_sampler():  | 
 | 724 | +    """Test that _resolve_overfit_batches does not alter custom samplers."""  | 
 | 725 | +    dataset = RandomDataset(32, 64)  | 
 | 726 | + | 
 | 727 | +    class CustomDictSampler(Sampler):  | 
 | 728 | +        def __init__(self, dataset):  | 
 | 729 | +            self.dataset = dataset  | 
 | 730 | + | 
 | 731 | +        def __iter__(self):  | 
 | 732 | +            return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])  | 
 | 733 | + | 
 | 734 | +        def __len__(self):  | 
 | 735 | +            return len(self.dataset)  | 
 | 736 | + | 
 | 737 | +    custom_sampler = CustomDictSampler(dataset)  | 
 | 738 | +    dataloader = DataLoader(dataset, sampler=custom_sampler, batch_size=2)  | 
 | 739 | +    combined_loader = CombinedLoader([dataloader])  | 
 | 740 | +    original_sampler = dataloader.sampler  | 
 | 741 | + | 
 | 742 | +    _resolve_overfit_batches(combined_loader, RunningStage.TRAINING)  | 
 | 743 | + | 
 | 744 | +    assert combined_loader.flattened[0].sampler is original_sampler  | 
 | 745 | +    assert combined_loader.flattened[0].sampler is custom_sampler  | 
 | 746 | +    assert isinstance(combined_loader.flattened[0].sampler, CustomDictSampler)  | 
0 commit comments