Skip to content

Commit 3f9acd3

Browse files
committed
add testing
1 parent cafb429 commit 3f9acd3

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020
from lightning_utilities.test.warning import no_warning_call
2121
from torch import Tensor
22-
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
22+
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler
2323

2424
import lightning.fabric
2525
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
@@ -30,6 +30,8 @@
3030
_check_dataloader_iterable,
3131
_DataHookSelector,
3232
_DataLoaderSource,
33+
_is_simple_sampler_replaceable,
34+
_resolve_overfit_batches,
3335
_worker_check,
3436
warning_cache,
3537
)
@@ -696,3 +698,49 @@ def test_iterable_check_on_known_iterators():
696698
dataloader.__iter__ = Mock()
697699
_check_dataloader_iterable(dataloader, Mock(), Mock())
698700
dataloader.__iter__.assert_not_called()
701+
702+
703+
def test_is_simple_sampler_replaceable():
704+
"""Test that _is_simple_sampler_replaceable correctly identifies simple vs custom samplers."""
705+
dataset = RandomDataset(32, 64)
706+
707+
assert _is_simple_sampler_replaceable(SequentialSampler(dataset)) is True
708+
assert _is_simple_sampler_replaceable(RandomSampler(dataset)) is True
709+
710+
class CustomSampler(Sampler):
711+
def __init__(self, dataset):
712+
self.dataset = dataset
713+
714+
def __iter__(self):
715+
return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])
716+
717+
def __len__(self):
718+
return len(self.dataset)
719+
720+
assert _is_simple_sampler_replaceable(CustomSampler(dataset)) is False
721+
722+
723+
def test_resolve_overfit_batches_preserves_custom_sampler():
724+
"""Test that _resolve_overfit_batches does not alter custom samplers."""
725+
dataset = RandomDataset(32, 64)
726+
727+
class CustomDictSampler(Sampler):
728+
def __init__(self, dataset):
729+
self.dataset = dataset
730+
731+
def __iter__(self):
732+
return iter([{"index": i, "param": 0.5} for i in range(len(self.dataset))])
733+
734+
def __len__(self):
735+
return len(self.dataset)
736+
737+
custom_sampler = CustomDictSampler(dataset)
738+
dataloader = DataLoader(dataset, sampler=custom_sampler, batch_size=2)
739+
combined_loader = CombinedLoader([dataloader])
740+
original_sampler = dataloader.sampler
741+
742+
_resolve_overfit_batches(combined_loader, RunningStage.TRAINING)
743+
744+
assert combined_loader.flattened[0].sampler is original_sampler
745+
assert combined_loader.flattened[0].sampler is custom_sampler
746+
assert isinstance(combined_loader.flattened[0].sampler, CustomDictSampler)

0 commit comments

Comments
 (0)