Skip to content

Commit cafb429

Browse files
committed
skip custom samplers
1 parent 4cd2336 commit cafb429

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,19 +244,48 @@ def _get_distributed_sampler(
244244
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
245245

246246

247+
def _is_simple_sampler_replaceable(sampler: Sampler) -> bool:
248+
"""Check if a sampler can be safely replaced with SequentialSampler for overfit batches."""
249+
simple_sampler_types = (
250+
RandomSampler,
251+
SequentialSampler,
252+
DistributedSampler,
253+
DistributedSamplerWrapper,
254+
UnrepeatedDistributedSamplerWrapper,
255+
)
256+
return isinstance(sampler, simple_sampler_types)
257+
258+
247259
def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
248260
"""Resolve overfit batches by disabling shuffling.
249261
250262
When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent
251263
batches across epochs. Training and validation use different sets of data.
252264
265+
For simple samplers (RandomSampler, SequentialSampler, etc.), they are replaced with SequentialSampler. For custom
266+
samplers that may use complex indexing, they are preserved but a warning is issued.
267+
253268
"""
254269
all_have_sequential_sampler = all(
255270
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
256271
)
257272
if all_have_sequential_sampler:
258273
return
259274

275+
# Check if any dataloaders have custom samplers that shouldn't be replaced
276+
has_custom_samplers = any(
277+
hasattr(dl, "sampler") and not _is_simple_sampler_replaceable(dl.sampler) for dl in combined_loader.flattened
278+
)
279+
280+
if has_custom_samplers:
281+
rank_zero_warn(
282+
f"You requested to overfit but some {mode.dataloader_prefix} dataloaders use custom samplers. "
283+
f"Custom samplers are preserved, but please ensure they provide deterministic, non-shuffled output "
284+
f"for consistent overfitting behavior.",
285+
category=PossibleUserWarning,
286+
)
287+
return
288+
260289
rank_zero_warn(
261290
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
262291
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."

0 commit comments

Comments
 (0)