|
30 | 30 | }, |
31 | 31 | { |
32 | 32 | "cell_type": "code", |
33 | | - "execution_count": 10, |
| 33 | + "execution_count": 1, |
34 | 34 | "metadata": {}, |
35 | 35 | "outputs": [], |
36 | 36 | "source": [ |
|
52 | 52 | }, |
53 | 53 | { |
54 | 54 | "cell_type": "code", |
55 | | - "execution_count": 11, |
| 55 | + "execution_count": 2, |
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [ |
58 | 58 | { |
59 | 59 | "name": "stdout", |
60 | 60 | "output_type": "stream", |
61 | 61 | "text": [ |
62 | | - "Base model accuracy: 0.9731398579808583\n" |
| 62 | + "Base model accuracy: 0.9739117011423278\n" |
63 | 63 | ] |
64 | 64 | } |
65 | 65 | ], |
|
86 | 86 | }, |
87 | 87 | { |
88 | 88 | "cell_type": "code", |
89 | | - "execution_count": 12, |
| 89 | + "execution_count": 3, |
90 | 90 | "metadata": {}, |
91 | 91 | "outputs": [ |
92 | 92 | { |
93 | 93 | "name": "stdout", |
94 | 94 | "output_type": "stream", |
95 | 95 | "text": [ |
96 | 96 | "1.0\n", |
97 | | - "0.026860142019141664\n", |
98 | | - "0.5134300710095708\n" |
| 97 | + "0.026088298857672165\n", |
| 98 | + "0.513044149428836\n" |
99 | 99 | ] |
100 | 100 | } |
101 | 101 | ], |
|
127 | 127 | }, |
128 | 128 | { |
129 | 129 | "cell_type": "code", |
130 | | - "execution_count": 13, |
| 130 | + "execution_count": 4, |
131 | 131 | "metadata": {}, |
132 | 132 | "outputs": [ |
133 | 133 | { |
134 | 134 | "name": "stdout", |
135 | 135 | "output_type": "stream", |
136 | 136 | "text": [ |
137 | | - "(0.5068064465654827, 1.0)\n" |
| 137 | + "(0.50660827402831, 1.0)\n" |
138 | 138 | ] |
139 | 139 | } |
140 | 140 | ], |
|
179 | 179 | }, |
180 | 180 | { |
181 | 181 | "cell_type": "code", |
182 | | - "execution_count": 14, |
| 182 | + "execution_count": 5, |
183 | 183 | "metadata": {}, |
184 | 184 | "outputs": [], |
185 | 185 | "source": [ |
|
205 | 205 | }, |
206 | 206 | { |
207 | 207 | "cell_type": "code", |
208 | | - "execution_count": 15, |
209 | | - "metadata": {}, |
| 208 | + "execution_count": 6, |
| 209 | + "metadata": { |
| 210 | + "scrolled": true |
| 211 | + }, |
210 | 212 | "outputs": [ |
211 | 213 | { |
212 | 214 | "name": "stdout", |
213 | 215 | "output_type": "stream", |
214 | 216 | "text": [ |
215 | | - "0.7313985798085829\n", |
216 | | - "0.5541833899351651\n", |
217 | | - "0.642790984871874\n" |
| 217 | + "0.6736647113306576\n", |
| 218 | + "0.6375424513738808\n", |
| 219 | + "0.6556035813522693\n" |
218 | 220 | ] |
219 | 221 | } |
220 | 222 | ], |
|
235 | 237 | "cell_type": "markdown", |
236 | 238 | "metadata": {}, |
237 | 239 | "source": [ |
238 | | - "Acheives slightly better results than the rule-based attack." |
| 240 | + "Achieves much better results than the rule-based attack." |
239 | 241 | ] |
240 | 242 | }, |
241 | 243 | { |
242 | 244 | "cell_type": "code", |
243 | | - "execution_count": 16, |
| 245 | + "execution_count": 7, |
244 | 246 | "metadata": {}, |
245 | 247 | "outputs": [ |
246 | 248 | { |
247 | 249 | "name": "stdout", |
248 | 250 | "output_type": "stream", |
249 | 251 | "text": [ |
250 | | - "(0.6212955677943877, 0.7313985798085829)\n" |
| 252 | + "(0.6501787842669845, 0.6736647113306576)\n" |
251 | 253 | ] |
252 | 254 | } |
253 | 255 | ], |
|
266 | 268 | }, |
267 | 269 | { |
268 | 270 | "cell_type": "code", |
269 | | - "execution_count": 17, |
| 271 | + "execution_count": 11, |
270 | 272 | "metadata": {}, |
271 | 273 | "outputs": [ |
272 | 274 | { |
273 | 275 | "name": "stdout", |
274 | 276 | "output_type": "stream", |
275 | 277 | "text": [ |
276 | | - "Base model accuracy: 0.9668107440568077\n" |
| 278 | + "Base model accuracy: 0.926\n" |
277 | 279 | ] |
278 | 280 | } |
279 | 281 | ], |
|
284 | 286 | "from torch.utils.data.dataset import Dataset\n", |
285 | 287 | "from art.estimators.classification.pytorch import PyTorchClassifier\n", |
286 | 288 | "\n", |
| 289 | + "# reduce size of training set to make attack slightly better\n", |
| 290 | + "train_set_size = 500\n", |
| 291 | + "x_train = x_train[:train_set_size]\n", |
| 292 | + "y_train = y_train[:train_set_size]\n", |
| 293 | + "x_test = x_test[:train_set_size]\n", |
| 294 | + "y_test = y_test[:train_set_size]\n", |
| 295 | + "attack_train_size = int(len(x_train) * attack_train_ratio)\n", |
| 296 | + "attack_test_size = int(len(x_test) * attack_train_ratio)\n", |
| 297 | + "\n", |
287 | 298 | "class ModelToAttack(nn.Module):\n", |
288 | 299 | "\n", |
289 | 300 | " def __init__(self, num_classes, num_features):\n", |
|
297 | 308 | " nn.Linear(1024, 512),\n", |
298 | 309 | " nn.Tanh(), )\n", |
299 | 310 | "\n", |
300 | | - " self.classifier = nn.Linear(512, num_classes)\n", |
301 | | - " # self.softmax = nn.Softmax(dim=1)\n", |
| 311 | + " self.fc3 = nn.Sequential(\n", |
| 312 | + " nn.Linear(512, 256),\n", |
| 313 | + " nn.Tanh(), )\n", |
| 314 | + " \n", |
| 315 | + " self.fc4 = nn.Sequential(\n", |
| 316 | + " nn.Linear(256, 128),\n", |
| 317 | + " nn.Tanh(),\n", |
| 318 | + " )\n", |
| 319 | + "\n", |
| 320 | + " self.classifier = nn.Linear(128, num_classes)\n", |
302 | 321 | "\n", |
303 | 322 | " def forward(self, x):\n", |
304 | 323 | " out = self.fc1(x)\n", |
305 | 324 | " out = self.fc2(out)\n", |
| 325 | + " out = self.fc3(out)\n", |
| 326 | + " out = self.fc4(out)\n", |
306 | 327 | " return self.classifier(out)\n", |
307 | 328 | "\n", |
308 | 329 | "mlp_model = ModelToAttack(4, 24)\n", |
309 | 330 | "mlp_model = torch.nn.DataParallel(mlp_model)\n", |
310 | 331 | "criterion = nn.CrossEntropyLoss()\n", |
311 | | - "optimizer = optim.Adam(mlp_model.parameters(), lr=0.0001)\n", |
| 332 | + "optimizer = optim.Adam(mlp_model.parameters(), lr=0.01)\n", |
312 | 333 | "\n", |
313 | 334 | "class NurseryDataset(Dataset):\n", |
314 | 335 | " def __init__(self, x, y=None):\n", |
|
358 | 379 | }, |
359 | 380 | { |
360 | 381 | "cell_type": "code", |
361 | | - "execution_count": 18, |
| 382 | + "execution_count": 12, |
362 | 383 | "metadata": {}, |
363 | 384 | "outputs": [ |
364 | 385 | { |
365 | 386 | "name": "stdout", |
366 | 387 | "output_type": "stream", |
367 | 388 | "text": [ |
368 | | - "0.9786971287434393\n", |
369 | | - "0.03318925594319233\n", |
370 | | - "0.5059431923433159\n", |
371 | | - "(0.5030548282155043, 0.9786971287434393)\n" |
| 389 | + "0.998\n", |
| 390 | + "0.07399999999999995\n", |
| 391 | + "0.536\n", |
| 392 | + "(0.5187110187110187, 0.998)\n" |
372 | 393 | ] |
373 | 394 | } |
374 | 395 | ], |
|
400 | 421 | }, |
401 | 422 | { |
402 | 423 | "cell_type": "code", |
403 | | - "execution_count": 19, |
| 424 | + "execution_count": 13, |
404 | 425 | "metadata": {}, |
405 | 426 | "outputs": [ |
406 | 427 | { |
407 | 428 | "name": "stdout", |
408 | 429 | "output_type": "stream", |
409 | 430 | "text": [ |
410 | | - "0.7488422352577956\n", |
411 | | - "0.7472985489348565\n", |
412 | | - "0.748070392096326\n", |
413 | | - "(0.7476880394574599, 0.7488422352577956)\n" |
| 431 | + "0.608\n", |
| 432 | + "0.5680000000000001\n", |
| 433 | + "0.588\n", |
| 434 | + "(0.5846153846153846, 0.608)\n" |
414 | 435 | ] |
415 | 436 | } |
416 | 437 | ], |
|
422 | 443 | " x_test[:attack_test_size].astype(np.float32), y_test[:attack_test_size])\n", |
423 | 444 | "\n", |
424 | 445 | "# infer \n", |
425 | | - "mlp_inferred_train_bb = mlp_attack_bb.infer(x_train.astype(np.float32), y_train)\n", |
426 | | - "mlp_inferred_test_bb = mlp_attack_bb.infer(x_test.astype(np.float32), y_test)\n", |
| 446 | + "mlp_inferred_train_bb = mlp_attack_bb.infer(x_train[attack_train_size:].astype(np.float32), y_train[attack_train_size:])\n", |
| 447 | + "mlp_inferred_test_bb = mlp_attack_bb.infer(x_test[attack_test_size:].astype(np.float32), y_test[attack_test_size:])\n", |
427 | 448 | "\n", |
428 | 449 | "# check accuracy\n", |
429 | 450 | "mlp_train_acc_bb = np.sum(mlp_inferred_train_bb) / len(mlp_inferred_train_bb)\n", |
|
441 | 462 | "cell_type": "markdown", |
442 | 463 | "metadata": {}, |
443 | 464 | "source": [ |
444 | | - "Using a random forest as the attack model we were able to acheive better performance than the rule-based attack, both in terms of accuracy and precision." |
| 465 | + "For the PyTorch target model we were able to achieve slightly better than random attack performance, but not as good as for the random forest model." |
445 | 466 | ] |
446 | 467 | } |
447 | 468 | ], |
|
461 | 482 | "name": "python", |
462 | 483 | "nbconvert_exporter": "python", |
463 | 484 | "pygments_lexer": "ipython3", |
464 | | - "version": "3.7.1" |
| 485 | + "version": "3.8.3" |
465 | 486 | } |
466 | 487 | }, |
467 | 488 | "nbformat": 4, |
|
0 commit comments