
Commit 62e3d58

Consume the prediction batch indices iteratively (#16826)
1 parent 598c247 commit 62e3d58

11 files changed: +120 lines, -107 lines


src/lightning/pytorch/CHANGELOG.md

Lines changed: 9 additions & 2 deletions

```diff
@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The selection `Trainer(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))
 
 
+- Predict's custom BatchSampler that tracks the batch indices no longer consumes the entire batch sampler at the beginning ([#16826](https://github.com/Lightning-AI/lightning/pull/16826))
+
+
 ### Deprecated
 
 -
@@ -237,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))
 
 
+- The `lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper` class is now marked as protected ([#16826](https://github.com/Lightning-AI/lightning/pull/16826))
+
+
 - Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))
 
 
@@ -362,6 +368,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
+
+
 - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 
 
@@ -373,8 +382,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))
 
 
-- Fixed bug where `set_epoch` was not called for prediction dataloaders ([#16785](https://github.com/Lightning-AI/lightning/pull/16785))
-
 ## [1.9.1] - 2023-02-10
 
 ### Fixed
```
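
The first changelog entry above is the core behavioral change of this commit: batch indices are recorded as batches are fetched rather than by exhausting the batch sampler up front. A minimal sketch of the difference, using plain PyTorch samplers (illustrative only, not Lightning's wrapper):

```python
from torch.utils.data import BatchSampler, SequentialSampler

batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)

# Eager: the whole batch sampler is consumed up front just to record indices.
all_indices = list(batch_sampler)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Iterative: indices are recorded only as batches are actually requested.
seen = []
iterator = iter(batch_sampler)
seen.append(next(iterator))  # seen == [[0, 1, 2, 3]] after the first batch
```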

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 5 additions & 9 deletions

```diff
@@ -11,7 +11,7 @@
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.progress import Progress
 from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
-from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
+from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
@@ -215,18 +215,14 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int
 
     def _get_batch_indices(self, dataloader: object) -> List[List[int]]:  # batches x samples
         """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
-        :class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`."""
+        :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`."""
         batch_sampler = getattr(dataloader, "batch_sampler", None)
-        if not isinstance(batch_sampler, IndexBatchSamplerWrapper):
+        if not isinstance(batch_sampler, _IndexBatchSamplerWrapper):
             self._warning_cache.warn(
                 f"Couldn't infer the batch indices fetched from your dataloader: `{type(dataloader).__name__}`"
             )
             return []
-        seen_batch_indices = batch_sampler.seen_batch_indices
-        # TODO(carmocca): this could be avoided
-        # we need to truncate the list because `IndexBatchSamplerWrapper` computes all indices on `__iter__`
-        seen_batch_indices = seen_batch_indices[: (self.batch_progress.current.completed + 1)]
-        return seen_batch_indices
+        return batch_sampler.seen_batch_indices
 
     def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) -> bool:
         prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)]
@@ -238,7 +234,7 @@ def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int)
             dataloader = combined_loader.flattened[dataloader_idx]
             batch_indices = self._get_batch_indices(dataloader)
             if not batch_indices:
-                # this is only available with `IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is
+                # this is only available with `_IndexBatchSamplerWrapper`, but it's only used on DataLoaders, if this is
                 # reached, it's likely because a non-DataLoader was passed
                 return any_on_epoch
             batch_indices = batch_indices[batch_idx]
```
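
Because the wrapper now appends to `seen_batch_indices` on every `__next__` call, the list only ever contains batches that have actually been fetched, so the truncation that used to align it with `batch_progress` becomes unnecessary. A tiny illustration of the before/after bookkeeping (values are made up for the example):

```python
# Old: __iter__ pre-computed every batch, so the loop had to slice the list down
# to the batches completed so far.
all_batches = [[0, 1], [2, 3], [4, 5]]  # everything, computed eagerly
completed = 0
old_view = all_batches[: completed + 1]  # [[0, 1]]

# New: the wrapper appends one entry per fetched batch, so the list already is the view.
new_view = [[0, 1]]
assert old_view == new_view
```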

src/lightning/pytorch/loops/utilities.py

Lines changed: 15 additions & 6 deletions

```diff
@@ -124,16 +124,25 @@ def _reset_progress(loop: _Loop) -> None:
 
 
 def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
-    """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
+    """Calls the ``set_epoch`` method on either the sampler of the given dataloader.
 
-    Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
+    Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
     :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
     of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
     """
-    for sampler_name in ("sampler", "batch_sampler"):
-        sampler = getattr(dataloader, sampler_name, None)
-        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
-            sampler.set_epoch(epoch)
+    objects = set()
+    # check dataloader.sampler
+    if (sampler := getattr(dataloader, "sampler", None)) is not None:
+        objects.add(sampler)
+    # check dataloader.batch_sampler.sampler
+    if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
+        sampler := getattr(batch_sampler, "sampler", None)
+    ) is not None:
+        objects.add(sampler)
+    for obj in objects:
+        set_epoch = getattr(obj, "set_epoch", None)
+        if callable(set_epoch):
+            set_epoch(epoch)
 
 
 def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher:
```
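
The rewritten `_set_sampler_epoch` gathers both `dataloader.sampler` and `dataloader.batch_sampler.sampler`, de-duplicates them, and calls `set_epoch` on whichever objects define it. A rough standalone sketch of the same lookup, using a stand-in sampler (illustrative names, not Lightning code):

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


class RecordingSampler(SequentialSampler):
    """Stand-in sampler that records set_epoch calls (illustrative only)."""

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


dataset = list(range(8))
sampler = RecordingSampler(dataset)
loader = DataLoader(dataset, batch_sampler=BatchSampler(sampler, batch_size=2, drop_last=False))

# Same idea as _set_sampler_epoch: collect candidate samplers, dedupe, call set_epoch.
candidates = {
    getattr(loader, "sampler", None),
    getattr(getattr(loader, "batch_sampler", None), "sampler", None),
}
for obj in candidates:
    if callable(getattr(obj, "set_epoch", None)):
        obj.set_epoch(3)

assert sampler.epoch == 3
```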

src/lightning/pytorch/overrides/distributed.py

Lines changed: 25 additions & 23 deletions

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterable, Iterator, List, Sized, Union
+from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
 
 import torch
 from torch import Tensor
@@ -108,34 +108,36 @@ def __iter__(self) -> Iterator:
         return (self.dataset[index] for index in super().__iter__())
 
 
-class IndexBatchSamplerWrapper:
+class _IndexBatchSamplerWrapper(BatchSampler):
     """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""
 
-    def __init__(self, sampler: BatchSampler) -> None:
+    def __init__(self, batch_sampler: BatchSampler) -> None:
+        # do not call super().__init__() on purpose
         self.seen_batch_indices: List[List[int]] = []
-        self._sampler = sampler
+
+        self.__dict__ = {
+            k: v
+            for k, v in batch_sampler.__dict__.items()
+            if k not in ("__next__", "__iter__", "__len__", "__getstate__")
+        }
+        self._batch_sampler = batch_sampler
+        self._iterator: Optional[Iterator[List[int]]] = None
+
+    def __next__(self) -> List[int]:
+        assert self._iterator is not None
+        batch = next(self._iterator)
+        self.seen_batch_indices.append(batch)
+        return batch
 
     def __iter__(self) -> Iterator[List[int]]:
         self.seen_batch_indices = []
-        for batch in self._sampler:
-            self.seen_batch_indices.append(batch)
-            yield batch
+        self._iterator = iter(self._batch_sampler)
+        return self
 
     def __len__(self) -> int:
-        return len(self._sampler)
-
-    @property
-    def drop_last(self) -> bool:
-        return self._sampler.drop_last
-
-    @property
-    def batch_size(self) -> int:
-        return self._sampler.batch_size
-
-    @property
-    def sampler(self) -> Union[Sampler, Iterable]:
-        return self._sampler.sampler
+        return len(self._batch_sampler)
 
-    def set_epoch(self, epoch: int) -> None:
-        if hasattr(self._sampler, "set_epoch"):
-            self._sampler.set_epoch(epoch)
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        state["_iterator"] = None  # cannot pickle 'generator' object
+        return state
```
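
The new `_IndexBatchSamplerWrapper` subclasses `BatchSampler`, copies the wrapped object's `__dict__` so attributes such as `batch_size`, `drop_last`, and `sampler` stay reachable without the old forwarding properties, and records each batch inside `__next__`. A usage sketch of that behavior, assuming a Lightning build that contains this commit so the class is importable as in the diff:

```python
from torch.utils.data import BatchSampler, SequentialSampler

from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper

inner = BatchSampler(SequentialSampler(range(6)), batch_size=2, drop_last=False)
wrapper = _IndexBatchSamplerWrapper(inner)

# Attributes of the wrapped batch sampler remain accessible via the copied __dict__.
assert wrapper.batch_size == 2 and wrapper.drop_last is False

iterator = iter(wrapper)  # resets seen_batch_indices and starts the inner iterator
next(iterator)
assert wrapper.seen_batch_indices == [[0, 1]]  # only the batch fetched so far
next(iterator)
assert wrapper.seen_batch_indices == [[0, 1], [2, 3]]
```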

src/lightning/pytorch/utilities/data.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -29,7 +29,7 @@
     has_iterable_dataset,
     sized_len,
 )
-from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper
+from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
@@ -246,7 +246,7 @@ def _dataloader_init_kwargs_resolve_sampler(
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
     re-instantiation.
 
-    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
+    If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so
     Lightning can keep track of its indices.
 
     If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
@@ -322,8 +322,9 @@ def _dataloader_init_kwargs_resolve_sampler(
             ) from e
 
         if is_predicting:
-            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
+            batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)
 
+        # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
        return {
             "sampler": None,
             "shuffle": False,
```

tests/tests_pytorch/loops/test_evaluation_loop.py

Lines changed: 13 additions & 38 deletions

```diff
@@ -43,18 +43,22 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
     assert eval_epoch_end_mock.call_count == 4
 
 
-def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
+@pytest.mark.parametrize("use_batch_sampler", (False, True))
+def test_evaluation_loop_sampler_set_epoch_called(tmp_path, use_batch_sampler):
     """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""
 
     def _get_dataloader():
         dataset = RandomDataset(32, 64)
         sampler = RandomSampler(dataset)
         sampler.set_epoch = Mock()
+        if use_batch_sampler:
+            batch_sampler = BatchSampler(sampler, 2, True)
+            return DataLoader(dataset, batch_sampler=batch_sampler)
         return DataLoader(dataset, sampler=sampler)
 
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
+        default_root_dir=tmp_path,
         limit_train_batches=1,
         limit_val_batches=1,
         max_epochs=2,
@@ -66,48 +70,19 @@ def _get_dataloader():
     train_dataloader = _get_dataloader()
     val_dataloader = _get_dataloader()
     trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
-    # One for each epoch
-    assert train_dataloader.sampler.set_epoch.mock_calls == [call(0), call(1)]
-    # One for each epoch + sanity check
-    assert val_dataloader.sampler.set_epoch.mock_calls == [call(0), call(0), call(1)]
-
-    val_dataloader = _get_dataloader()
-    trainer.validate(model, val_dataloader)
-    assert val_dataloader.sampler.set_epoch.mock_calls == [call(2)]
-
-
-def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
-    """Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""
-
-    def _get_dataloader():
-        dataset = RandomDataset(32, 64)
-        sampler = RandomSampler(dataset)
-        batch_sampler = BatchSampler(sampler, 2, True)
-        batch_sampler.set_epoch = Mock()
-        return DataLoader(dataset, batch_sampler=batch_sampler)
-
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=1,
-        limit_val_batches=1,
-        max_epochs=2,
-        enable_model_summary=False,
-        enable_checkpointing=False,
-        logger=False,
-    )
+    train_sampler = train_dataloader.batch_sampler.sampler if use_batch_sampler else train_dataloader.sampler
+    val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler
 
-    train_dataloader = _get_dataloader()
-    val_dataloader = _get_dataloader()
-    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
     # One for each epoch
-    assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
+    assert train_sampler.set_epoch.mock_calls == [call(0), call(1)]
     # One for each epoch + sanity check
-    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
+    assert val_sampler.set_epoch.mock_calls == [call(0), call(0), call(1)]
 
     val_dataloader = _get_dataloader()
     trainer.validate(model, val_dataloader)
-    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]
+    val_sampler = val_dataloader.batch_sampler.sampler if use_batch_sampler else val_dataloader.sampler
+
+    assert val_sampler.set_epoch.mock_calls == [call(2)]
 
 
 @mock.patch(
```

tests/tests_pytorch/loops/test_prediction_loop.py

Lines changed: 27 additions & 9 deletions

```diff
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from unittest import mock
-from unittest.mock import call
 
 import pytest
+from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
 
 from lightning.pytorch import Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
+from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 
 
 def test_prediction_loop_stores_predictions(tmp_path):
@@ -51,21 +51,39 @@ def predict_step(self, batch, batch_idx):
     assert trainer.predict_loop.predictions == []
 
 
-def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path):
+@pytest.mark.parametrize("replace_sampler_ddp", (False, True))
+def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sampler_ddp):
     """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
-    model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmp_path,
         limit_predict_batches=1,
         enable_model_summary=False,
         enable_checkpointing=False,
         logger=False,
+        strategy="ddp",
+        devices=1,
+        accelerator="cpu",
+        replace_sampler_ddp=replace_sampler_ddp,
     )
-    trainer.fit_loop.epoch_progress.current.processed = 2
 
-    with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock:
-        trainer.predict(model)
-    assert set_epoch_mock.mock_calls == [call(2)]
+    class MyModel(BoringModel):
+        def predict_dataloader(self):
+            dataset = RandomDataset(32, 64)
+            sampler = None
+            if not replace_sampler_ddp:
+                sampler = DistributedSampler(dataset)
+            return DataLoader(dataset, sampler=sampler)
+
+    model = MyModel()
+    trainer.fit_loop.epoch_progress.current.processed = 2
+    trainer.predict(model)
+
+    # torch will set this .sampler attribute for backwards compatibility, but in reality, the batch sampler is used
+    assert isinstance(trainer.predict_dataloaders.sampler, SequentialSampler)
+    batch_sampler = trainer.predict_dataloaders.batch_sampler
+    assert isinstance(batch_sampler, _IndexBatchSamplerWrapper)
+    assert isinstance(batch_sampler.sampler, DistributedSampler)
+    assert batch_sampler.sampler.epoch == 2
 
 
 def test_prediction_loop_with_iterable_dataset(tmp_path):
```

tests/tests_pytorch/loops/test_utilities.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -33,4 +33,4 @@ def test_set_sampler_epoch():
     dataloader = Mock()
     _set_sampler_epoch(dataloader, 55)
     dataloader.sampler.set_epoch.assert_called_once_with(55)
-    dataloader.batch_sampler.set_epoch.assert_called_once_with(55)
+    dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
```
