Skip to content

Commit dfffb94

Browse files
Move predictions to CPU before accumulating (#9085)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b576201 commit dfffb94

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from collections import OrderedDict
22
from typing import Any, Dict, Iterator, List, Optional, Tuple
33

4+
import torch
45
from deprecate import void
56

67
from pytorch_lightning.loops.base import Loop
78
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
89
from pytorch_lightning.trainer.progress import Progress
10+
from pytorch_lightning.utilities.apply_func import move_data_to_device
911
from pytorch_lightning.utilities.warnings import WarningCache
1012

1113

@@ -140,7 +142,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
140142
self.batch_progress.increment_completed()
141143

142144
if self.should_store_predictions:
143-
self.predictions.append(predictions)
145+
self.predictions.append(move_data_to_device(predictions, torch.device("cpu")))
144146

145147
def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
146148
"""

0 commit comments

Comments
 (0)