Skip to content

Commit c2d333f

Browse files
committed
get device model is running on to move hook input onto
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 430af84 commit c2d333f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

art/estimators/classification/hugging_face.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,13 @@ def _make_model_wrapper(self, model: "torch.nn.Module") -> "torch.nn.Module":
155155
import torch
156156

157157
input_shape = self._input_shape
158-
input_for_hook = torch.rand(input_shape).to(self.device)
158+
input_for_hook = torch.rand(input_shape)
159+
# self.device may not match the device the raw model was passed into ART.
160+
# Check if the model is on cuda, if so set the hook input accordingly
161+
if next(model.parameters()).is_cuda:
162+
cuda_idx = torch.cuda.current_device()
163+
input_for_hook = input_for_hook.to(torch.device(f"cuda:{cuda_idx}"))
164+
159165
input_for_hook = torch.unsqueeze(input_for_hook, dim=0)
160166

161167
if self.processor is not None:

0 commit comments

Comments
 (0)