Skip to content

Commit 430af84

Browse files
committed
updating notebook to confirm fix
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 4f2a479 commit 430af84

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

notebooks/huggingface_notebook.ipynb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
"\n",
3535
"import numpy as np\n",
3636
"from torchvision import datasets\n",
37-
"from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch"
37+
"from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch\n",
38+
"\n",
39+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
3840
]
3941
},
4042
{
@@ -86,7 +88,7 @@
8688
" model = transformers.AutoModelForImageClassification.from_pretrained(architecture,\n",
8789
" ignore_mismatched_sizes=True,\n",
8890
" num_labels=10)\n",
89-
"\n",
91+
" model = model.to(device)\n",
9092
" # The HuggingFaceClassifierPyTorch follows broadly the same API as the PyTorchClassifier\n",
9193
" # So we can supply the loss function, the input shape of the data we will supply, the optimizer, etc.\n",
9294
" # Note, frequently we will be performing fine-tuning or transfer learning with vision transformners and \n",
@@ -533,6 +535,7 @@
533535
" import timm\n",
534536
" \n",
535537
" model = timm.create_model('resnet50.a1_in1k', pretrained=True)\n",
538+
" model = model.to(device)\n",
536539
" upsampler = torch.nn.Upsample(scale_factor=7, mode='nearest')\n",
537540
" \n",
538541
" optimizer = Adam(model.parameters(), lr=1e-3)\n",
@@ -816,7 +819,7 @@
816819
"name": "python",
817820
"nbconvert_exporter": "python",
818821
"pygments_lexer": "ipython3",
819-
"version": "3.11.3"
822+
"version": "3.8.10"
820823
}
821824
},
822825
"nbformat": 4,

0 commit comments

Comments
 (0)