Commit 0449e86

Fix trainer.predict(return_predictions=False) does not track batch_indices (#13629)
* Pull request for fixing issue #13580
* chlog and test
* disable track for epoch

Co-authored-by: rohitgr7 <[email protected]>
1 parent d18f45b commit 0449e86
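
For context, a minimal sketch of the scenario this commit fixes. The names (`RandomDataset`, `TinyModel`, `IndexLoggingWriter`) and the dataset sizes are illustrative only, not taken from the PR: a batch-interval `BasePredictionWriter` used together with `trainer.predict(..., return_predictions=False)`. The commit makes Lightning track the per-batch sample indices in this setup so `write_on_batch_end` receives them.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter


class RandomDataset(Dataset):
    # Toy dataset of 64 random feature vectors (sizes are illustrative).
    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)


class IndexLoggingWriter(BasePredictionWriter):
    # Hypothetical writer: only reports which sample indices each batch covered.
    def __init__(self):
        super().__init__(write_interval="batch")

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # With this commit, batch_indices is populated (e.g. [0, 1, 2, 3] for the
        # first batch of a sequential dataloader) even though return_predictions=False.
        print(f"batch {batch_idx}: indices {batch_indices}")


if __name__ == "__main__":
    trainer = pl.Trainer(limit_predict_batches=4, callbacks=IndexLoggingWriter(), logger=False)
    # return_predictions=False keeps the trainer from accumulating predictions in memory.
    trainer.predict(TinyModel(), dataloaders=DataLoader(RandomDataset(), batch_size=4), return_predictions=False)

Passing `return_predictions=False` is exactly the memory-saving case where a prediction writer has to do the bookkeeping itself, which is why the indices still need to be tracked.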

3 files changed: +37 -9 lines changed

3 files changed

+37
-9
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -342,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417))


+- Fixed `Trainer.predict(return_predictions=False)` to track prediction's batch_indices ([#13629](https://github.com/Lightning-AI/lightning/pull/13629))
+
+
 ## [1.6.5] - 2022-07-13

 ### Fixed

src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 5 additions & 4 deletions
@@ -67,7 +67,7 @@ def on_run_start(  # type: ignore[override]
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders
         # this call requires that `self.return_predictions` is set
-        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else []

     def advance(  # type: ignore[override]
         self,
@@ -87,7 +87,7 @@ def advance(  # type: ignore[override]
         action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
         with self.trainer.profiler.profile(action_name):
             batch_idx, batch = next(dataloader_iter)
-        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else []
         # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
         self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]

@@ -119,7 +119,8 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
         step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

         # extract batch_indices and store them
-        self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []
+        batch_indices = self._get_batch_indices(dataloader_idx)
+        self.current_batch_indices = batch_indices[batch_idx] if batch_indices else []

         self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
         self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
@@ -166,7 +167,7 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
             "batch_sampler",
             None,
         )
-        if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
+        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            return batch_sampler.seen_batch_indices

         warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
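
To make the mechanics above easier to follow, here is a simplified, standalone sketch; it is not Lightning's actual `IndexBatchSamplerWrapper`, and the class name `IndexRecordingBatchSampler` is made up. The idea, as the diff shows, is that `_get_batch_indices` reads `seen_batch_indices` off a batch-sampler wrapper that records each batch of indices as the dataloader consumes it, and `advance` truncates that list to the batches the loop has actually reached, because the dataloader may prefetch ahead. After this change, `should_store_predictions` only gates the epoch-level `_seen_batch_indices` cache, while `_predict_step` fetches the batch-level indices directly so they reach the callbacks either way.

from torch.utils.data import BatchSampler, SequentialSampler


class IndexRecordingBatchSampler:
    # Simplified stand-in for Lightning's IndexBatchSamplerWrapper: it records
    # every batch of sample indices that the dataloader pulls from it.
    def __init__(self, batch_sampler):
        self._batch_sampler = batch_sampler
        self.seen_batch_indices = []

    def __iter__(self):
        self.seen_batch_indices = []
        for batch in self._batch_sampler:
            self.seen_batch_indices.append(batch)
            yield batch

    def __len__(self):
        return len(self._batch_sampler)


sampler = IndexRecordingBatchSampler(BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False))
it = iter(sampler)

# Simulate the dataloader prefetching one batch ahead of the prediction loop.
next(it)  # the loop's current batch (batch 0)
next(it)  # prefetched batch 1, not yet handed to the loop

completed = 0  # the loop has not completed any batch yet
# Mirror of the truncation in `advance`: drop indices for batches that were only prefetched.
seen = sampler.seen_batch_indices[: completed + 1]
print(seen)  # [[0, 1, 2, 3]]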

tests/tests_pytorch/callbacks/test_prediction_writer.py

Lines changed: 29 additions & 5 deletions
@@ -26,10 +26,10 @@


 class DummyPredictionWriter(BasePredictionWriter):
-    def write_on_batch_end(self, *args, **kwargs):
+    def write_on_batch_end(self, *_, **__):
         pass

-    def write_on_epoch_end(self, *args, **kwargs):
+    def write_on_epoch_end(self, *_, **__):
         pass


@@ -39,7 +39,7 @@ def test_prediction_writer_invalid_write_interval():
         DummyPredictionWriter("something")


-def test_prediction_writer_hook_call_intervals(tmpdir):
+def test_prediction_writer_hook_call_intervals():
     """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
     interval."""
     DummyPredictionWriter.write_on_batch_end = Mock()
@@ -84,7 +84,7 @@ def test_prediction_writer_hook_call_intervals(tmpdir):


 @pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
-def test_prediction_writer_batch_indices(tmpdir, num_workers):
+def test_prediction_writer_batch_indices(num_workers):
     DummyPredictionWriter.write_on_batch_end = Mock()
     DummyPredictionWriter.write_on_epoch_end = Mock()

@@ -110,7 +110,7 @@ def test_prediction_writer_batch_indices(tmpdir, num_workers):
     )


-def test_prediction_writer_partial_support_for_combined_loader(tmpdir):
+def test_prediction_writer_partial_support_for_combined_loader():
     """Test partial support for CombinedLoader: prediction works but sample indices don't get tracked."""
     pl.loops.epoch.prediction_epoch_loop.warning_cache.clear()

@@ -140,3 +140,27 @@ def predict_step(self, batch, *args, **kwargs):
     )

     writer.write_on_epoch_end.assert_has_calls([call(trainer, model, ANY, [[]])])
+
+
+def test_batch_level_batch_indices():
+    """Test that batch_indices are returned when `return_predictions=False`."""
+    DummyPredictionWriter.write_on_batch_end = Mock()
+
+    class CustomBoringModel(BoringModel):
+        def on_predict_epoch_end(self, *args, **kwargs):
+            assert self.trainer.predict_loop.epoch_batch_indices == [[]]
+
+    writer = DummyPredictionWriter("batch")
+    model = CustomBoringModel()
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4)
+    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+
+    writer.write_on_batch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
+            call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
+            call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
+            call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
+        ]
+    )
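
As a usage note, the behaviour pinned down by the new `test_batch_level_batch_indices` test is what makes a disk-backed writer practical: predictions can be written per batch, keyed by their sample indices, without the trainer holding them in memory. A sketch follows; `DiskPredictionWriter` and the output path are hypothetical, not part of this PR.

import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class DiskPredictionWriter(BasePredictionWriter):
    # Hypothetical example: persist each batch's predictions alongside the sample
    # indices that produced them, so results can be reassembled later.
    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        torch.save(
            {"indices": batch_indices, "prediction": prediction},
            os.path.join(self.output_dir, f"dl{dataloader_idx}_batch{batch_idx}.pt"),
        )


# Usage (assuming an existing model and dataloader):
# trainer = Trainer(callbacks=DiskPredictionWriter("predictions/"))
# trainer.predict(model, dataloaders=dataloader, return_predictions=False)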
