Commit 6a09f27

Browse files
Authored by ved1beta, pre-commit-ci[bot], and Borda

docs: overfit_batches uses same batch for train and val (#20731)

* fix: overfit_batches uses same batch for train and val
* docs changes for better understanding

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 72bb751 · commit 6a09f27

File tree

4 files changed: +62 -4 lines changed


docs/source-pytorch/common/trainer.rst

Lines changed: 7 additions & 4 deletions

@@ -759,6 +759,9 @@ overfit_batches
 Uses this much data of the training & validation set.
 If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.
 
+* When set to a value > 0, sequential sampling (no shuffling) is used
+* Consistent batches are used for both training and validation across epochs, but training and validation use different sets of data
+
 Useful for quickly debugging or trying to overfit on purpose.
 
 .. testcode::

@@ -769,11 +772,11 @@ Useful for quickly debugging or trying to overfit on purpose.
     # use only 1% of the train & val set
     trainer = Trainer(overfit_batches=0.01)
 
-    # overfit on 10 of the same batches
+    # overfit on 10 consistent train batches & 10 consistent val batches
     trainer = Trainer(overfit_batches=10)
 
-plugins
-^^^^^^^
+    # debug using a single consistent train batch and a single consistent val batch
+    trainer = Trainer(overfit_batches=1)
 
 :ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example:

@@ -895,7 +898,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``.
 
     # if 0 (default)
     train_loader = model.train_dataloader()
-    # or if using data module: datamodule.train_dataloader()
+    # or if using data module: datamodule.train_dataloaders()
    for epoch in epochs:
        for batch in train_loader:
            ...
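The behavior documented above is easy to try directly. Below is a minimal sketch, not part of the commit, that fits a shuffled dataloader with overfit_batches set; the model and dataset helpers are the public demo classes from lightning.pytorch.demos.boring_classes, and the warning mentioned in the comment is the one emitted by data_connector.py in the next file.

# Hedged sketch (assumes lightning 2.x and its bundled demo classes).
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

# 64 random samples of dimension 32, served in shuffled batches of 8
train_loader = DataLoader(RandomDataset(32, 64), batch_size=8, shuffle=True)

# overfit_batches=2 repeats the first two batches every epoch; because the
# loader is shuffled, Lightning warns and swaps in sequential sampling.
trainer = Trainer(overfit_batches=2, max_epochs=2, enable_model_summary=False)
trainer.fit(BoringModel(), train_loader)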

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 8 additions & 0 deletions

@@ -244,15 +244,23 @@ def _get_distributed_sampler(
 
 
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
+    """Resolve overfit batches by disabling shuffling.
+
+    When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent
+    batches across epochs. Training and validation use different sets of data.
+
+    """
     all_have_sequential_sampler = all(
         isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
     )
     if all_have_sequential_sampler:
         return
+
     rank_zero_warn(
         f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
         f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
     )
+
     updated = [
         _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
         for dl in combined_loader.flattened
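For readers unfamiliar with _update_dataloader, the effect of the call in the list comprehension above is roughly the following. This is a simplified, hypothetical sketch rather than the library implementation (the real helper re-instantiates the loader while preserving all of its original arguments):

# Hedged sketch: rebuild a DataLoader around a SequentialSampler so every
# epoch iterates the dataset in the same fixed 0..N-1 order.
from torch.utils.data import DataLoader, SequentialSampler

def with_sequential_sampler(dl: DataLoader) -> DataLoader:
    if isinstance(dl.sampler, SequentialSampler):
        return dl  # already sequential, mirroring the early return above
    return DataLoader(
        dl.dataset,
        batch_size=dl.batch_size,
        sampler=SequentialSampler(dl.dataset),  # deterministic order, no shuffle
        num_workers=dl.num_workers,
        collate_fn=dl.collate_fn,
        drop_last=dl.drop_last,
    )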

tests/tests_pytorch/conftest.py

Lines changed: 6 additions & 0 deletions

@@ -95,6 +95,12 @@ def restore_env_variables():
         "TF_GRPC_DEFAULT_OPTIONS",
         "XLA_FLAGS",
         "TORCHINDUCTOR_CACHE_DIR",  # leaked by torch.compile
+        # TensorFlow and TPU related variables
+        "TF2_BEHAVIOR",
+        "TPU_ML_PLATFORM",
+        "TPU_ML_PLATFORM_VERSION",
+        "LD_LIBRARY_PATH",
+        "ENABLE_RUNTIME_UPTIME_TELEMETRY",
     }
     leaked_vars.difference_update(allowlist)
     assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
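For context, restore_env_variables is a fixture that snapshots os.environ before each test and fails if a test leaves new variables behind; the allowlist above names known benign leaks. A condensed sketch of that pattern (helper name and signature assumed from the assertion message, not copied from the file):

import os

def check_no_leaks(env_before: dict, allowlist: set) -> None:
    # anything set during the test and not allowlisted counts as a leak
    leaked_vars = set(os.environ) - set(env_before) - allowlist
    assert not leaked_vars, f"test is leaking environment variable(s): {leaked_vars}"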

tests/tests_pytorch/trainer/flags/test_overfit_batches.py

Lines changed: 41 additions & 0 deletions

@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
     train_sampler = trainer.train_dataloader.sampler
     assert isinstance(train_sampler, DistributedSampler)
     assert train_sampler.shuffle is False
+
+
+def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
+    """Test that when overfit_batches=1, the same batch is used for both training and validation."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.train_batches = []
+            self.val_batches = []
+
+        def training_step(self, batch, batch_idx):
+            self.train_batches.append(batch)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.val_batches.append(batch)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        overfit_batches=1,
+        check_val_every_n_epoch=1,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+
+    # Verify that the same batch was used for both training and validation
+    assert len(model.train_batches) > 0
+    assert len(model.val_batches) > 0
+
+    # Compare the actual batch contents
+    train_batch = model.train_batches[0]
+    val_batch = model.val_batches[0]
+
+    # Check if the batches are identical
+    assert torch.equal(train_batch, val_batch), (
+        "Training and validation batches should be identical when overfit_batches=1"
+    )
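The new test needs only a single process, so it should run in isolation with standard pytest selection, for example python -m pytest tests/tests_pytorch/trainer/flags/test_overfit_batches.py -k same_batch_for_train_and_val from a development checkout (command assumed, not part of the commit). Note that the torch.equal comparison is valid here because BoringModel batches are plain tensors.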
