|
93 | 93 | "id": "69e3ffe8", |
94 | 94 | "metadata": {}, |
95 | 95 | "outputs": [], |
96 | | - "source": [ |
97 | | - "# num_classes = 10\n", |
98 | | - "# loss_fn = nn.CrossEntropyLoss()\n", |
99 | | - "# model = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)\n", |
100 | | - "# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n", |
101 | | - "# model_art = PyTorchClassifier(model,input_shape=x_train.shape[1:], loss=loss_fn, optimizer=optimizer, nb_classes=10, clip_values=(min_, max_), preprocessing=(mean,std))\n", |
102 | | - "# model_art.fit(x_train, y_train, batch_size=128, nb_epochs=80,verbose=0)\n", |
103 | | - "# predictions = model_art.predict(x_test)\n", |
104 | | - "# accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n", |
105 | | - "# print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))" |
106 | | - ] |
107 | | - }, |
108 | | - { |
109 | | - "cell_type": "code", |
110 | | - "execution_count": 5, |
111 | | - "id": "3b807d14", |
112 | | - "metadata": {}, |
113 | | - "outputs": [ |
114 | | - { |
115 | | - "name": "stdout", |
116 | | - "output_type": "stream", |
117 | | - "text": [ |
118 | | - "Accuracy on benign test examples: 73.95%\n" |
119 | | - ] |
120 | | - } |
121 | | - ], |
122 | 96 | "source": [ |
123 | 97 | "num_classes = 10\n", |
124 | 98 | "loss_fn = nn.CrossEntropyLoss()\n", |
125 | | - "model = torch.load('model.pt')\n", |
| 99 | + "model = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)\n", |
126 | 100 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n", |
127 | 101 | "model_art = PyTorchClassifier(model,input_shape=x_train.shape[1:], loss=loss_fn, optimizer=optimizer, nb_classes=10, clip_values=(min_, max_), preprocessing=(mean,std))\n", |
| 102 | + "model_art.fit(x_train, y_train, batch_size=128, nb_epochs=80,verbose=0)\n", |
128 | 103 | "predictions = model_art.predict(x_test)\n", |
129 | 104 | "accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n", |
130 | 105 | "print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))" |
131 | 106 | ] |
132 | 107 | }, |
133 | | - { |
134 | | - "cell_type": "code", |
135 | | - "execution_count": 6, |
136 | | - "id": "d973a78c", |
137 | | - "metadata": {}, |
138 | | - "outputs": [ |
139 | | - { |
140 | | - "data": { |
141 | | - "text/plain": [ |
142 | | - "False" |
143 | | - ] |
144 | | - }, |
145 | | - "execution_count": 6, |
146 | | - "metadata": {}, |
147 | | - "output_type": "execute_result" |
148 | | - } |
149 | | - ], |
150 | | - "source": [ |
151 | | - "model_art.model.training" |
152 | | - ] |
153 | | - }, |
154 | 108 | { |
155 | 109 | "cell_type": "markdown", |
156 | 110 | "id": "9b1ca858", |
|
601 | 555 | "name": "python", |
602 | 556 | "nbconvert_exporter": "python", |
603 | 557 | "pygments_lexer": "ipython3", |
604 | | - "version": "3.8.12" |
| 558 | + "version": "3.9.9" |
605 | 559 | }, |
606 | 560 | "vscode": { |
607 | 561 | "interpreter": { |
|
0 commit comments