diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 7044ccea87a7f..dcfd873a28b4b 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -233,8 +233,9 @@ def _predict_step( self.batch_progress.increment_ready() - if not using_dataloader_iter: - any_on_epoch = self._store_data_for_prediction_writer(batch_idx, dataloader_idx) + any_on_epoch = ( + self._store_data_for_prediction_writer(batch_idx, dataloader_idx) if not using_dataloader_iter else False + ) # the `_step` methods don't take a batch_idx when `dataloader_iter` is used, but all other hooks still do, # so we need different kwargs