Skip to content

Commit 4f2a479

Browse files
committed
move hook input to original model device
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 337a15f commit 4f2a479

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

art/estimators/classification/hugging_face.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _make_model_wrapper(self, model: "torch.nn.Module") -> "torch.nn.Module":
155155
import torch
156156

157157
input_shape = self._input_shape
158-
input_for_hook = torch.rand(input_shape)
158+
input_for_hook = torch.rand(input_shape).to(self.device)
159159
input_for_hook = torch.unsqueeze(input_for_hook, dim=0)
160160

161161
if self.processor is not None:

0 commit comments

Comments
 (0)