Commit 3b7122b

Adding reconstruction attack notebook
1 parent 00f002d commit 3b7122b

1 file changed

+346
-0
lines changed

@@ -0,0 +1,346 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Running database reconstruction attacks on the Iris dataset"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"In this tutorial we will show how to run a database reconstruction attack on the Iris dataset."
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {},
20+
"source": [
21+
"## Preliminaries"
22+
]
23+
},
24+
{
25+
"cell_type": "markdown",
26+
"metadata": {},
27+
"source": [
28+
"The database reconstruction attack takes a machine learning model `model` that has been trained on a dataset of `n` examples. Given `n-1` of those examples (i.e., the training dataset with the target row removed), we seek to reconstruct the remaining `n`th example by using `model`.\n",
29+
"\n",
30+
"In this example, we train a Gaussian naive Bayes classifier (`model`) on the training dataset, then remove a single row from that dataset and seek to reconstruct that row using `model`. For typical examples, this attack is successful up to machine precision.\n",
31+
"\n",
32+
"We then show that training the ML model with differential privacy protects the training dataset: the same attack can no longer learn the target row with precision."
33+
]
34+
},
35+
{
36+
"cell_type": "markdown",
37+
"metadata": {},
38+
"source": [
39+
"## Example usage"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"## Load data"
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"metadata": {},
52+
"source": [
53+
"First, we load the data of interest and split it into train/test subsets."
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": 1,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"from sklearn import datasets\n",
63+
"from sklearn.model_selection import train_test_split\n",
64+
"import numpy as np\n",
65+
"\n",
66+
"dataset = datasets.load_iris()"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 2,
72+
"metadata": {},
73+
"outputs": [],
74+
"source": [
75+
"x_train, x_test, y_train, y_test = train_test_split(dataset.data, dataset.target, test_size=0.2)"
76+
]
77+
},
78+
{
79+
"cell_type": "markdown",
80+
"metadata": {},
81+
"source": [
82+
"## Train model"
83+
]
84+
},
85+
{
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"We can now train a Gaussian naive Bayes classifier using the full training dataset. This is the model that will be used to attack the training dataset later."
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 3,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"import sklearn.naive_bayes as naive_bayes\n",
99+
"from art.estimators.classification.scikitlearn import ScikitlearnGaussianNB\n",
100+
"\n",
101+
"model1 = naive_bayes.GaussianNB().fit(x_train, y_train)\n",
102+
"non_private_art = ScikitlearnGaussianNB(model1)"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": 4,
108+
"metadata": {},
109+
"outputs": [
110+
{
111+
"name": "stdout",
112+
"output_type": "stream",
113+
"text": [
114+
"Model accuracy (on the test dataset): 1.0\n"
115+
]
116+
}
117+
],
118+
"source": [
119+
"print(\"Model accuracy (on the test dataset): {}\".format(model1.score(x_test, y_test)))"
120+
]
121+
},
122+
{
123+
"cell_type": "markdown",
124+
"metadata": {},
125+
"source": [
126+
"## Launch and evaluate attack"
127+
]
128+
},
129+
{
130+
"cell_type": "markdown",
131+
"metadata": {},
132+
"source": [
133+
"We now select a row to remove from the training dataset. This is the **target row** that the attack will seek to reconstruct. The attacker has access to the remaining rows, `x_public` and `y_public`."
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 5,
139+
"metadata": {},
140+
"outputs": [],
141+
"source": [
142+
"target_row = np.random.randint(x_train.shape[0])\n",
143+
"\n",
144+
"x_public = np.delete(x_train, target_row, axis=0)\n",
145+
"y_public = np.delete(y_train, target_row, axis=0)"
146+
]
147+
},
148+
{
149+
"cell_type": "markdown",
150+
"metadata": {},
151+
"source": [
152+
"We can now launch the attack, and seek to infer the value of the target row. This is typically completed in less than a second."
153+
]
154+
},
155+
{
156+
"cell_type": "code",
157+
"execution_count": 6,
158+
"metadata": {},
159+
"outputs": [],
172+
"source": [
173+
"from art.attacks.inference.reconstruction import DatabaseReconstruction\n",
174+
"\n",
175+
"dbrecon = DatabaseReconstruction(non_private_art)\n",
176+
"\n",
177+
"x, y = dbrecon.reconstruct(x_public, y_public)"
178+
]
179+
},
180+
{
181+
"cell_type": "markdown",
182+
"metadata": {},
183+
"source": [
184+
"We can evaluate the accuracy of the attack using root-mean-square error (RMSE), showing a high level of accuracy in the inferred value."
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": null,
190+
"metadata": {},
191+
"outputs": [],
192+
"source": [
193+
"print(\"Inference RMSE: {}\".format(\n",
194+
" np.sqrt(((x_train[target_row] - x) ** 2).sum() / x_train.shape[1])))"
195+
]
196+
},
197+
{
198+
"cell_type": "markdown",
199+
"metadata": {},
200+
"source": [
201+
"We can confirm that the attack also inferred the correct label `y`."
202+
]
203+
},
204+
{
205+
"cell_type": "code",
206+
"execution_count": null,
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"np.argmax(y) == y_train[target_row]"
211+
]
212+
},
213+
{
214+
"cell_type": "markdown",
215+
"metadata": {},
216+
"source": [
217+
"# Attacking a model trained with differential privacy"
218+
]
219+
},
220+
{
221+
"cell_type": "markdown",
222+
"metadata": {},
223+
"source": [
224+
"We can mitigate this attack by training the public ML model with differential privacy. We will use [diffprivlib](https://github.com/IBM/differential-privacy-library) to train a differentially private Gaussian naive Bayes classifier. The privacy parameter `epsilon` controls the trade-off: smaller values give stronger privacy but lower model accuracy, so we choose a value appropriate to our needs."
225+
]
226+
},
227+
{
228+
"cell_type": "markdown",
229+
"metadata": {},
230+
"source": [
231+
"## Train the model"
232+
]
233+
},
234+
{
235+
"cell_type": "code",
236+
"execution_count": null,
237+
"metadata": {},
238+
"outputs": [],
239+
"source": [
240+
"from diffprivlib import models\n",
241+
"\n",
242+
"model2 = models.GaussianNB(bounds=([4.3, 2.0, 1.0, 0.1], [7.9, 4.4, 6.9, 2.5]), epsilon=3).fit(x_train, y_train)\n",
243+
"private_art = ScikitlearnGaussianNB(model2)\n",
244+
"\n",
245+
"model2.score(x_test, y_test)"
246+
]
247+
},
248+
{
249+
"cell_type": "markdown",
250+
"metadata": {},
251+
"source": [
252+
"## Launch and evaluate attack"
253+
]
254+
},
255+
{
256+
"cell_type": "markdown",
257+
"metadata": {},
258+
"source": [
259+
"We then launch the same attack as before. In this case, the attack may take several seconds to return a result."
260+
]
261+
},
262+
{
263+
"cell_type": "code",
264+
"execution_count": null,
265+
"metadata": {},
266+
"outputs": [],
267+
"source": [
268+
"dbrecon = DatabaseReconstruction(private_art)\n",
269+
"\n",
270+
"x_dp, y_dp = dbrecon.reconstruct(x_public, y_public)"
271+
]
272+
},
273+
{
274+
"cell_type": "markdown",
275+
"metadata": {},
276+
"source": [
277+
"In this case, the RMSE shows our attack has not been as successful."
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"metadata": {},
284+
"outputs": [],
285+
"source": [
286+
"print(\"Inference RMSE (with differential privacy): {}\".format(\n",
287+
" np.sqrt(((x_train[target_row] - x_dp) ** 2).sum() / x_train.shape[1])))"
288+
]
289+
},
290+
{
291+
"cell_type": "markdown",
292+
"metadata": {},
293+
"source": [
294+
"This is confirmed by inspecting the inferred value and the true value."
295+
]
296+
},
297+
{
298+
"cell_type": "code",
299+
"execution_count": null,
300+
"metadata": {
301+
"scrolled": false
302+
},
303+
"outputs": [],
304+
"source": [
305+
"x_dp, x_train[target_row]"
306+
]
307+
},
308+
{
309+
"cell_type": "markdown",
310+
"metadata": {},
311+
"source": [
312+
"In fact, the attack may not even be able to correctly infer the target label."
313+
]
314+
},
315+
{
316+
"cell_type": "code",
317+
"execution_count": null,
318+
"metadata": {},
319+
"outputs": [],
320+
"source": [
321+
"np.argmax(y_dp), y_train[target_row]"
322+
]
323+
}
324+
],
325+
"metadata": {
326+
"kernelspec": {
327+
"display_name": "Python 3",
328+
"language": "python",
329+
"name": "python3"
330+
},
331+
"language_info": {
332+
"codemirror_mode": {
333+
"name": "ipython",
334+
"version": 3
335+
},
336+
"file_extension": ".py",
337+
"mimetype": "text/x-python",
338+
"name": "python",
339+
"nbconvert_exporter": "python",
340+
"pygments_lexer": "ipython3",
341+
"version": "3.6.8"
342+
}
343+
},
344+
"nbformat": 4,
345+
"nbformat_minor": 4
346+
}
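
The notebook states that the attack on the non-private model is typically successful up to machine precision. A minimal sketch of why that is plausible for Gaussian naive Bayes follows. It is not the algorithm used by ART's `DatabaseReconstruction`, whose internals are not shown in this commit. It is a closed-form illustration under the assumption that the attacker can read the exact per-class feature means and class counts from the trained model (scikit-learn's `GaussianNB` exposes these as `theta_` and `class_count_`). The helpers `class_sums`, `fit_means`, `reconstruct_row` and `rmse` are hypothetical names introduced here, and the data is random rather than Iris. Only the standard library is used so the sketch runs without the notebook's dependencies.

```python
import math
import random


def class_sums(rows, labels):
    """Return per-class feature sums and per-class row counts."""
    sums, counts = {}, {}
    for row, label in zip(rows, labels):
        acc = sums.setdefault(label, [0.0] * len(row))
        for j, value in enumerate(row):
            acc[j] += value
        counts[label] = counts.get(label, 0) + 1
    return sums, counts


def fit_means(rows, labels):
    """Toy stand-in for training: keep only class means and class counts."""
    sums, counts = class_sums(rows, labels)
    means = {c: [s / counts[c] for s in sums[c]] for c in sums}
    return means, counts


def reconstruct_row(means, counts, public_rows, public_labels):
    """Recover the one missing row from model statistics and the n-1 public rows."""
    public_sums, public_counts = class_sums(public_rows, public_labels)
    # The target's class is the only one whose count is one higher in the model.
    label = next(c for c in counts if counts[c] == public_counts.get(c, 0) + 1)
    n_features = len(means[label])
    known = public_sums.get(label, [0.0] * n_features)
    # count * mean is the class sum over all n rows; subtract the known rows' share.
    row = [counts[label] * means[label][j] - known[j] for j in range(n_features)]
    return row, label


def rmse(a, b):
    """Root-mean-square error between two equal-length feature vectors."""
    return math.sqrt(sum((u - v) ** 2 for u, v in zip(a, b)) / len(a))


# Synthetic data (an assumption for illustration): 30 rows, 4 features, 3 classes.
rng = random.Random(0)
labels = [i % 3 for i in range(30)]
rows = [[rng.uniform(0.1, 8.0) for _ in range(4)] for _ in labels]
means, counts = fit_means(rows, labels)

# Remove one target row; the attacker sees only the remaining n-1 rows.
target = rng.randrange(len(rows))
public_rows = rows[:target] + rows[target + 1:]
public_labels = labels[:target] + labels[target + 1:]

guess, guess_label = reconstruct_row(means, counts, public_rows, public_labels)
error = rmse(rows[target], guess)
print("Inference RMSE:", error)
print("Label recovered:", guess_label == labels[target])
```

Because the stored means are exact functions of the training rows, the only error left is floating-point rounding. A differentially private model perturbs these statistics with noise, so the same subtraction no longer isolates the target row, which is the effect the second half of the notebook measures with RMSE.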
