Skip to content

Commit 31f7931

Browse files
authored
Merge pull request #1310 from abigailgold/dev_1.8.0_notebook_fix
Fix error in membership attack notebook
2 parents 088e0d4 + e8720aa commit 31f7931

File tree

1 file changed

+57
-36
lines changed

1 file changed

+57
-36
lines changed

notebooks/attack_membership_inference.ipynb

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
},
3131
{
3232
"cell_type": "code",
33-
"execution_count": 10,
33+
"execution_count": 1,
3434
"metadata": {},
3535
"outputs": [],
3636
"source": [
@@ -52,14 +52,14 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": 11,
55+
"execution_count": 2,
5656
"metadata": {},
5757
"outputs": [
5858
{
5959
"name": "stdout",
6060
"output_type": "stream",
6161
"text": [
62-
"Base model accuracy: 0.9731398579808583\n"
62+
"Base model accuracy: 0.9739117011423278\n"
6363
]
6464
}
6565
],
@@ -86,16 +86,16 @@
8686
},
8787
{
8888
"cell_type": "code",
89-
"execution_count": 12,
89+
"execution_count": 3,
9090
"metadata": {},
9191
"outputs": [
9292
{
9393
"name": "stdout",
9494
"output_type": "stream",
9595
"text": [
9696
"1.0\n",
97-
"0.026860142019141664\n",
98-
"0.5134300710095708\n"
97+
"0.026088298857672165\n",
98+
"0.513044149428836\n"
9999
]
100100
}
101101
],
@@ -127,14 +127,14 @@
127127
},
128128
{
129129
"cell_type": "code",
130-
"execution_count": 13,
130+
"execution_count": 4,
131131
"metadata": {},
132132
"outputs": [
133133
{
134134
"name": "stdout",
135135
"output_type": "stream",
136136
"text": [
137-
"(0.5068064465654827, 1.0)\n"
137+
"(0.50660827402831, 1.0)\n"
138138
]
139139
}
140140
],
@@ -179,7 +179,7 @@
179179
},
180180
{
181181
"cell_type": "code",
182-
"execution_count": 14,
182+
"execution_count": 5,
183183
"metadata": {},
184184
"outputs": [],
185185
"source": [
@@ -205,16 +205,18 @@
205205
},
206206
{
207207
"cell_type": "code",
208-
"execution_count": 15,
209-
"metadata": {},
208+
"execution_count": 6,
209+
"metadata": {
210+
"scrolled": true
211+
},
210212
"outputs": [
211213
{
212214
"name": "stdout",
213215
"output_type": "stream",
214216
"text": [
215-
"0.7313985798085829\n",
216-
"0.5541833899351651\n",
217-
"0.642790984871874\n"
217+
"0.6736647113306576\n",
218+
"0.6375424513738808\n",
219+
"0.6556035813522693\n"
218220
]
219221
}
220222
],
@@ -235,19 +237,19 @@
235237
"cell_type": "markdown",
236238
"metadata": {},
237239
"source": [
238-
"Acheives slightly better results than the rule-based attack."
240+
"Acheives much better results than the rule-based attack."
239241
]
240242
},
241243
{
242244
"cell_type": "code",
243-
"execution_count": 16,
245+
"execution_count": 7,
244246
"metadata": {},
245247
"outputs": [
246248
{
247249
"name": "stdout",
248250
"output_type": "stream",
249251
"text": [
250-
"(0.6212955677943877, 0.7313985798085829)\n"
252+
"(0.6501787842669845, 0.6736647113306576)\n"
251253
]
252254
}
253255
],
@@ -266,14 +268,14 @@
266268
},
267269
{
268270
"cell_type": "code",
269-
"execution_count": 17,
271+
"execution_count": 11,
270272
"metadata": {},
271273
"outputs": [
272274
{
273275
"name": "stdout",
274276
"output_type": "stream",
275277
"text": [
276-
"Base model accuracy: 0.9668107440568077\n"
278+
"Base model accuracy: 0.926\n"
277279
]
278280
}
279281
],
@@ -284,6 +286,15 @@
284286
"from torch.utils.data.dataset import Dataset\n",
285287
"from art.estimators.classification.pytorch import PyTorchClassifier\n",
286288
"\n",
289+
"# reduce size of training set to make attack slightly better\n",
290+
"train_set_size = 500\n",
291+
"x_train = x_train[:train_set_size]\n",
292+
"y_train = y_train[:train_set_size]\n",
293+
"x_test = x_test[:train_set_size]\n",
294+
"y_test = y_test[:train_set_size]\n",
295+
"attack_train_size = int(len(x_train) * attack_train_ratio)\n",
296+
"attack_test_size = int(len(x_test) * attack_train_ratio)\n",
297+
"\n",
287298
"class ModelToAttack(nn.Module):\n",
288299
"\n",
289300
" def __init__(self, num_classes, num_features):\n",
@@ -297,18 +308,28 @@
297308
" nn.Linear(1024, 512),\n",
298309
" nn.Tanh(), )\n",
299310
"\n",
300-
" self.classifier = nn.Linear(512, num_classes)\n",
301-
" # self.softmax = nn.Softmax(dim=1)\n",
311+
" self.fc3 = nn.Sequential(\n",
312+
" nn.Linear(512, 256),\n",
313+
" nn.Tanh(), )\n",
314+
" \n",
315+
" self.fc4 = nn.Sequential(\n",
316+
" nn.Linear(256, 128),\n",
317+
" nn.Tanh(),\n",
318+
" )\n",
319+
"\n",
320+
" self.classifier = nn.Linear(128, num_classes)\n",
302321
"\n",
303322
" def forward(self, x):\n",
304323
" out = self.fc1(x)\n",
305324
" out = self.fc2(out)\n",
325+
" out = self.fc3(out)\n",
326+
" out = self.fc4(out)\n",
306327
" return self.classifier(out)\n",
307328
"\n",
308329
"mlp_model = ModelToAttack(4, 24)\n",
309330
"mlp_model = torch.nn.DataParallel(mlp_model)\n",
310331
"criterion = nn.CrossEntropyLoss()\n",
311-
"optimizer = optim.Adam(mlp_model.parameters(), lr=0.0001)\n",
332+
"optimizer = optim.Adam(mlp_model.parameters(), lr=0.01)\n",
312333
"\n",
313334
"class NurseryDataset(Dataset):\n",
314335
" def __init__(self, x, y=None):\n",
@@ -358,17 +379,17 @@
358379
},
359380
{
360381
"cell_type": "code",
361-
"execution_count": 18,
382+
"execution_count": 12,
362383
"metadata": {},
363384
"outputs": [
364385
{
365386
"name": "stdout",
366387
"output_type": "stream",
367388
"text": [
368-
"0.9786971287434393\n",
369-
"0.03318925594319233\n",
370-
"0.5059431923433159\n",
371-
"(0.5030548282155043, 0.9786971287434393)\n"
389+
"0.998\n",
390+
"0.07399999999999995\n",
391+
"0.536\n",
392+
"(0.5187110187110187, 0.998)\n"
372393
]
373394
}
374395
],
@@ -400,17 +421,17 @@
400421
},
401422
{
402423
"cell_type": "code",
403-
"execution_count": 19,
424+
"execution_count": 13,
404425
"metadata": {},
405426
"outputs": [
406427
{
407428
"name": "stdout",
408429
"output_type": "stream",
409430
"text": [
410-
"0.7488422352577956\n",
411-
"0.7472985489348565\n",
412-
"0.748070392096326\n",
413-
"(0.7476880394574599, 0.7488422352577956)\n"
431+
"0.608\n",
432+
"0.5680000000000001\n",
433+
"0.588\n",
434+
"(0.5846153846153846, 0.608)\n"
414435
]
415436
}
416437
],
@@ -422,8 +443,8 @@
422443
" x_test[:attack_test_size].astype(np.float32), y_test[:attack_test_size])\n",
423444
"\n",
424445
"# infer \n",
425-
"mlp_inferred_train_bb = mlp_attack_bb.infer(x_train.astype(np.float32), y_train)\n",
426-
"mlp_inferred_test_bb = mlp_attack_bb.infer(x_test.astype(np.float32), y_test)\n",
446+
"mlp_inferred_train_bb = mlp_attack_bb.infer(x_train[attack_train_size:].astype(np.float32), y_train[attack_train_size:])\n",
447+
"mlp_inferred_test_bb = mlp_attack_bb.infer(x_test[attack_test_size:].astype(np.float32), y_test[attack_test_size:])\n",
427448
"\n",
428449
"# check accuracy\n",
429450
"mlp_train_acc_bb = np.sum(mlp_inferred_train_bb) / len(mlp_inferred_train_bb)\n",
@@ -441,7 +462,7 @@
441462
"cell_type": "markdown",
442463
"metadata": {},
443464
"source": [
444-
"Using a random forest as the attack model we were able to acheive better performance than the rule-based attack, both in terms of accuracy and precision."
465+
"For the pytorch target model we were able to acheive slightly better than random attack performance, but not as good as for the random forest model."
445466
]
446467
}
447468
],
@@ -461,7 +482,7 @@
461482
"name": "python",
462483
"nbconvert_exporter": "python",
463484
"pygments_lexer": "ipython3",
464-
"version": "3.7.1"
485+
"version": "3.8.3"
465486
}
466487
},
467488
"nbformat": 4,

0 commit comments

Comments
 (0)