Skip to content

Commit 897c4da

Browse files
committed
Updating pytorch notebook
Signed-off-by: Shriti Priya <[email protected]>
1 parent c1262ac commit 897c4da

File tree

1 file changed

+3
-49
lines changed

1 file changed

+3
-49
lines changed

notebooks/poisoning_attack_sleeper_agent_pytorch.ipynb

Lines changed: 3 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -93,64 +93,18 @@
9393
"id": "69e3ffe8",
9494
"metadata": {},
9595
"outputs": [],
96-
"source": [
97-
"# num_classes = 10\n",
98-
"# loss_fn = nn.CrossEntropyLoss()\n",
99-
"# model = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)\n",
100-
"# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n",
101-
"# model_art = PyTorchClassifier(model,input_shape=x_train.shape[1:], loss=loss_fn, optimizer=optimizer, nb_classes=10, clip_values=(min_, max_), preprocessing=(mean,std))\n",
102-
"# model_art.fit(x_train, y_train, batch_size=128, nb_epochs=80,verbose=0)\n",
103-
"# predictions = model_art.predict(x_test)\n",
104-
"# accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n",
105-
"# print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))"
106-
]
107-
},
108-
{
109-
"cell_type": "code",
110-
"execution_count": 5,
111-
"id": "3b807d14",
112-
"metadata": {},
113-
"outputs": [
114-
{
115-
"name": "stdout",
116-
"output_type": "stream",
117-
"text": [
118-
"Accuracy on benign test examples: 73.95%\n"
119-
]
120-
}
121-
],
12296
"source": [
12397
"num_classes = 10\n",
12498
"loss_fn = nn.CrossEntropyLoss()\n",
125-
"model = torch.load('model.pt')\n",
99+
"model = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)\n",
126100
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)\n",
127101
"model_art = PyTorchClassifier(model,input_shape=x_train.shape[1:], loss=loss_fn, optimizer=optimizer, nb_classes=10, clip_values=(min_, max_), preprocessing=(mean,std))\n",
102+
"model_art.fit(x_train, y_train, batch_size=128, nb_epochs=80,verbose=0)\n",
128103
"predictions = model_art.predict(x_test)\n",
129104
"accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n",
130105
"print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))"
131106
]
132107
},
133-
{
134-
"cell_type": "code",
135-
"execution_count": 6,
136-
"id": "d973a78c",
137-
"metadata": {},
138-
"outputs": [
139-
{
140-
"data": {
141-
"text/plain": [
142-
"False"
143-
]
144-
},
145-
"execution_count": 6,
146-
"metadata": {},
147-
"output_type": "execute_result"
148-
}
149-
],
150-
"source": [
151-
"model_art.model.training"
152-
]
153-
},
154108
{
155109
"cell_type": "markdown",
156110
"id": "9b1ca858",
@@ -601,7 +555,7 @@
601555
"name": "python",
602556
"nbconvert_exporter": "python",
603557
"pygments_lexer": "ipython3",
604-
"version": "3.8.12"
558+
"version": "3.9.9"
605559
},
606560
"vscode": {
607561
"interpreter": {

0 commit comments

Comments
 (0)