
Commit 6baa5cc

Browse files
rohitgr7 authored and tchaton committed
fix overfit_batch sampler replacement logic (#10486)
Co-authored-by: thomas chaton <[email protected]>
1 parent 4bc6e95 commit 6baa5cc

3 files changed: +82 -19 lines changed

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
@@ -12,6 +12,26 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 - Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))

+- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+
+
+- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/pytorch-lightning/pull/10493))
+
+
+- Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
+
+
+- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
+
+
+- Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
+
+
+-
+
+
+-
+

 ## [1.5.1] - 2021-11-09
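
The `overfit_batches` entry above is the headline fix: the old logic only replaced a dataloader's sampler when it was literally a `RandomSampler`, so any other shuffling sampler slipped through. A minimal sketch of the difference between the two checks, in plain torch (the `CustomSampler` class here is hypothetical, for illustration only):

import torch
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler, TensorDataset

class CustomSampler(Sampler):
    # Shuffles indices, but is not a RandomSampler subclass.
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)

dataset = TensorDataset(torch.randn(8, 2))
loader = DataLoader(dataset, sampler=CustomSampler(dataset))

# Old check: a custom shuffling sampler is not a RandomSampler, so it was kept.
print(isinstance(loader.sampler, RandomSampler))          # False -> sampler kept
# New check: anything other than a SequentialSampler gets replaced.
print(not isinstance(loader.sampler, SequentialSampler))  # True -> sampler replaced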

pytorch_lightning/trainer/data_loading.py

Lines changed: 9 additions & 9 deletions
@@ -438,8 +438,7 @@ def _reset_eval_dataloader(
         for loader_i in range(len(dataloaders)):
             loader = dataloaders[loader_i]

-            if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):
-
+            if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
                 # when overfitting, the dataloader should not have sampler
                 if self.overfit_batches > 0 and mode.evaluating:
                     rank_zero_warn(
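
The same predicate change appears in `_reset_eval_dataloader`: evaluation loaders are now flagged with `not isinstance(loader.sampler, SequentialSampler)` instead of `isinstance(loader.sampler, RandomSampler)`, so when overfitting, the warning also fires for custom or weighted samplers on val/test dataloaders, not just for shuffled ones.
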
@@ -591,16 +590,17 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:

     @staticmethod
     def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
-        has_random_sampler = False
+        all_have_sequential_sampler = True

-        def resolve_had_random_sampler(dataloader: DataLoader):
-            nonlocal has_random_sampler
-            if not has_random_sampler:
-                has_random_sampler = isinstance(dataloader.sampler, RandomSampler)
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+            nonlocal all_have_sequential_sampler
+            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
+                dataloader.sampler, SequentialSampler
+            )

-        apply_to_collection(dataloader, DataLoader, resolve_had_random_sampler)
+        apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)

-        if has_random_sampler:
+        if not all_have_sequential_sampler:
             rank_zero_warn(
                 "You requested to overfit but enabled training dataloader shuffling."
                 " We are turning off the training dataloader shuffling for you."

tests/trainer/flags/test_overfit_batches.py

Lines changed: 53 additions & 10 deletions
@@ -13,13 +13,16 @@
 # limitations under the License.
 import pytest
 import torch
+from torch.utils.data.sampler import Sampler, SequentialSampler

 from pytorch_lightning import Trainer
 from tests.helpers.boring_model import BoringModel, RandomDataset


 def test_overfit_multiple_val_loaders(tmpdir):
-    """Tests that only training_step can be used."""
+    """Tests that overfit batches works with multiple val dataloaders."""
+    val_dl_count = 2
+    overfit_batches = 3

     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx, dataloader_idx):
@@ -31,25 +34,65 @@ def validation_epoch_end(self, outputs) -> None:
             pass

         def val_dataloader(self):
-            dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
-            return [dl1, dl2]
+            dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)]
+            return dls

     model = TestModel()

     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=2, overfit_batches=1, log_every_n_steps=1, enable_model_summary=False
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        overfit_batches=overfit_batches,
+        log_every_n_steps=1,
+        enable_model_summary=False,
     )

     trainer.fit(model)
+    assert trainer.num_training_batches == overfit_batches
+    assert len(trainer.num_val_batches) == val_dl_count
+    assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches)


-@pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0])
-def test_overfit_basic(tmpdir, overfit):
-    """Tests that only training_step can be used."""
+@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
+def test_overfit_basic(tmpdir, overfit_batches):
+    """Tests that only training_step can be used when overfitting."""

     model = BoringModel()
+    model.validation_step = None
+    total_train_samples = len(BoringModel().train_dataloader())

-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit, enable_model_summary=False)
-
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit_batches, enable_model_summary=False
+    )
     trainer.fit(model)
+
+    assert trainer.num_val_batches == []
+    assert trainer.num_training_batches == int(
+        overfit_batches * (1 if isinstance(overfit_batches, int) else total_train_samples)
+    )
+
+
+def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
+    class NonSequentialSampler(Sampler):
+        def __init__(self, data_source):
+            self.data_source = data_source
+
+        def __iter__(self):
+            return iter(range(len(self.data_source)))
+
+        def __len__(self):
+            return len(self.data_source)
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            dataset = RandomDataset(32, 64)
+            sampler = NonSequentialSampler(dataset)
+            return torch.utils.data.DataLoader(dataset, sampler=sampler)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
+
+    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+        trainer.fit(model)
+
+    assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
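
The new test pins down both halves of the fix: the warning fires for any non-sequential sampler, and after fitting, the train dataloader's sampler has been swapped for a `SequentialSampler`, so every overfit epoch iterates the same batches in the same order. To run it in isolation from a checkout at this commit (assuming a standard dev install of the repository):

pytest tests/trainer/flags/test_overfit_batches.py -k sequential_sampler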
