|
34 | 34 | "\n", |
35 | 35 | "import numpy as np\n", |
36 | 36 | "from torchvision import datasets\n", |
37 | | - "from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch" |
| 37 | + "from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch\n", |
| 38 | + "\n", |
| 39 | + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" |
38 | 40 | ] |
39 | 41 | }, |
40 | 42 | { |
|
86 | 88 | " model = transformers.AutoModelForImageClassification.from_pretrained(architecture,\n", |
87 | 89 | " ignore_mismatched_sizes=True,\n", |
88 | 90 | " num_labels=10)\n", |
89 | | - "\n", |
| 91 | + " model = model.to(device)\n", |
90 | 92 | " # The HuggingFaceClassifierPyTorch follows broadly the same API as the PyTorchClassifier\n", |
91 | 93 | " # So we can supply the loss function, the input shape of the data we will supply, the optimizer, etc.\n", |
92 | 94 | " # Note, frequently we will be performing fine-tuning or transfer learning with vision transformers and \n", |
|
533 | 535 | " import timm\n", |
534 | 536 | " \n", |
535 | 537 | " model = timm.create_model('resnet50.a1_in1k', pretrained=True)\n", |
| 538 | + " model = model.to(device)\n", |
536 | 539 | " upsampler = torch.nn.Upsample(scale_factor=7, mode='nearest')\n", |
537 | 540 | " \n", |
538 | 541 | " optimizer = Adam(model.parameters(), lr=1e-3)\n", |
|
816 | 819 | "name": "python", |
817 | 820 | "nbconvert_exporter": "python", |
818 | 821 | "pygments_lexer": "ipython3", |
819 | | - "version": "3.11.3" |
| 822 | + "version": "3.8.10" |
820 | 823 | } |
821 | 824 | }, |
822 | 825 | "nbformat": 4, |
|
0 commit comments