|
45 | 45 | "\n", |
46 | 46 | "from art.utils import load_nursery\n", |
47 | 47 | "\n", |
48 | | - "(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.2, transform_social=True)" |
| 48 | + "(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5, transform_social=True)" |
49 | 49 | ] |
50 | 50 | }, |
51 | 51 | { |
|
64 | 64 | "name": "stdout", |
65 | 65 | "output_type": "stream", |
66 | 66 | "text": [ |
67 | | - "Base model accuracy: 0.9791666666666666\n" |
| 67 | + "Base model accuracy: 0.9705155912318617\n" |
68 | 68 | ] |
69 | 69 | } |
70 | 70 | ], |
|
100 | 100 | "\n", |
101 | 101 | "attack_train_ratio = 0.5\n", |
102 | 102 | "attack_train_size = int(len(x_train) * attack_train_ratio)\n", |
| 103 | + "attack_test_size = int(len(x_train) * attack_train_ratio)\n", |
103 | 104 | "attack_x_train = x_train[:attack_train_size]\n", |
104 | 105 | "attack_y_train = y_train[:attack_train_size]\n", |
105 | 106 | "attack_x_test = x_train[attack_train_size:]\n", |
|
136 | 137 | "name": "stdout", |
137 | 138 | "output_type": "stream", |
138 | 139 | "text": [ |
139 | | - "0.5937861829409494\n" |
| 140 | + "0.5998765050941649\n" |
140 | 141 | ] |
141 | 142 | } |
142 | 143 | ], |
|
174 | 175 | "name": "stdout", |
175 | 176 | "output_type": "stream", |
176 | 177 | "text": [ |
177 | | - "0.6227325357005017\n" |
| 178 | + "0.6288978079654214\n" |
178 | 179 | ] |
179 | 180 | } |
180 | 181 | ], |
|
209 | 210 | "name": "stdout", |
210 | 211 | "output_type": "stream", |
211 | 212 | "text": [ |
212 | | - "0.7001157854110382\n" |
| 213 | + "0.7005248533497993\n" |
213 | 214 | ] |
214 | 215 | } |
215 | 216 | ], |
|
244 | 245 | "name": "stdout", |
245 | 246 | "output_type": "stream", |
246 | 247 | "text": [ |
247 | | - "(0.3501577287066246, 0.2573913043478261)\n", |
248 | | - "(0.34417344173441733, 0.1472463768115942)\n", |
249 | | - "(0.6309341500765697, 0.23884057971014494)\n" |
| 248 | + "(0.34232954545454547, 0.22439478584729983)\n", |
| 249 | + "(0.32320441988950277, 0.10893854748603352)\n", |
| 250 | + "(0.652046783625731, 0.20763500931098697)\n" |
250 | 251 | ] |
251 | 252 | } |
252 | 253 | ], |
|
299 | 300 | "name": "stdout", |
300 | 301 | "output_type": "stream", |
301 | 302 | "text": [ |
302 | | - "0.5372443072172907\n" |
| 303 | + "0.5433775856745909\n" |
303 | 304 | ] |
304 | 305 | } |
305 | 306 | ], |
|
344 | 345 | "\n", |
345 | 346 | "mem_attack = MembershipInferenceBlackBox(art_classifier)\n", |
346 | 347 | "\n", |
347 | | - "mem_attack.fit(x_train[:attack_train_size], y_train[:attack_train_size], x_test, y_test)" |
| 348 | + "mem_attack.fit(x_train[:attack_train_size], y_train[:attack_train_size], x_test[:attack_test_size], y_test[:attack_test_size])" |
348 | 349 | ] |
349 | 350 | }, |
350 | 351 | { |
|
356 | 357 | }, |
357 | 358 | { |
358 | 359 | "cell_type": "code", |
359 | | - "execution_count": 10, |
| 360 | + "execution_count": 11, |
360 | 361 | "metadata": {}, |
361 | 362 | "outputs": [ |
362 | 363 | { |
363 | 364 | "name": "stdout", |
364 | 365 | "output_type": "stream", |
365 | 366 | "text": [ |
366 | | - "0.6358548822848321\n" |
| 367 | + "0.6335288669342389\n" |
367 | 368 | ] |
368 | 369 | } |
369 | 370 | ], |
|
0 commit comments