Skip to content

Commit dec5ddf

Browse files
authored
Merge pull request #2300 from GiulioZizzo/hf_model_wrapper_update
Huggingface model wrapper update
2 parents 1b3120b + e896821 commit dec5ddf

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

art/estimators/classification/hugging_face.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def _make_model_wrapper(self, model: "torch.nn.Module") -> "torch.nn.Module":
156156

157157
input_shape = self._input_shape
158158
input_for_hook = torch.rand(input_shape)
159+
# self.device may not match the device the raw model was passed into ART.
160+
# Check if the model is on cuda, if so set the hook input accordingly
161+
if next(model.parameters()).is_cuda:
162+
cuda_idx = torch.cuda.current_device()
163+
input_for_hook = input_for_hook.to(torch.device(f"cuda:{cuda_idx}"))
164+
159165
input_for_hook = torch.unsqueeze(input_for_hook, dim=0)
160166

161167
if self.processor is not None:

notebooks/huggingface_notebook.ipynb

Lines changed: 19 additions & 17 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)