Skip to content

Commit ab11c98

Browse files
authored
Merge pull request #956 from abigailgold/dev_1.6.0_att_baseline
Add baseline attribute attack
2 parents 5421cdb + e1d3b70 commit ab11c98

File tree

5 files changed

+337
-26
lines changed

5 files changed

+337
-26
lines changed

art/attacks/inference/attribute_inference/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Module providing attribute inference attacks.
33
"""
44
from art.attacks.inference.attribute_inference.black_box import AttributeInferenceBlackBox
5+
from art.attacks.inference.attribute_inference.baseline import AttributeInferenceBaseline
56
from art.attacks.inference.attribute_inference.white_box_decision_tree import AttributeInferenceWhiteBoxDecisionTree
67
from art.attacks.inference.attribute_inference.white_box_lifestyle_decision_tree import (
78
AttributeInferenceWhiteBoxLifestyleDecisionTree,
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements attribute inference attacks.
20+
"""
21+
from __future__ import absolute_import, division, print_function, unicode_literals
22+
23+
import logging
24+
from typing import Optional, Union, TYPE_CHECKING
25+
26+
import numpy as np
27+
from sklearn.neural_network import MLPClassifier
28+
29+
from art.estimators.estimator import BaseEstimator
30+
from art.estimators.classification.classifier import ClassifierMixin
31+
from art.attacks.attack import AttributeInferenceAttack
32+
from art.utils import check_and_transform_label_format, float_to_categorical, floats_to_one_hot
33+
34+
if TYPE_CHECKING:
35+
from art.utils import CLASSIFIER_TYPE
36+
37+
logger = logging.getLogger(__name__)
38+
39+
40+
class AttributeInferenceBaseline(AttributeInferenceAttack):
41+
"""
42+
Implementation of a baseline attribute inference, not using a model.
43+
44+
The idea is to train a simple neural network to learn the attacked feature from the rest of the features. Should
45+
be used to compare with other attribute inference results.
46+
"""
47+
_estimator_requirements = ()
48+
49+
def __init__(
50+
self,
51+
attack_model: Optional["CLASSIFIER_TYPE"] = None,
52+
attack_feature: Union[int, slice] = 0,
53+
):
54+
"""
55+
Create an AttributeInferenceBaseline attack instance.
56+
57+
:param attack_model: The attack model to train, optional. If none is provided, a default model will be created.
58+
:param attack_feature: The index of the feature to be attacked or a slice representing multiple indexes in
59+
case of a one-hot encoded feature.
60+
"""
61+
super().__init__(estimator=None, attack_feature=attack_feature)
62+
63+
if isinstance(self.attack_feature, int):
64+
self.single_index_feature = True
65+
else:
66+
self.single_index_feature = False
67+
68+
if attack_model:
69+
if ClassifierMixin not in type(attack_model).__mro__:
70+
raise ValueError("Attack model must be of type Classifier.")
71+
self.attack_model = attack_model
72+
else:
73+
self.attack_model = MLPClassifier(
74+
hidden_layer_sizes=(100,),
75+
activation="relu",
76+
solver="adam",
77+
alpha=0.0001,
78+
batch_size="auto",
79+
learning_rate="constant",
80+
learning_rate_init=0.001,
81+
power_t=0.5,
82+
max_iter=2000,
83+
shuffle=True,
84+
random_state=None,
85+
tol=0.0001,
86+
verbose=False,
87+
warm_start=False,
88+
momentum=0.9,
89+
nesterovs_momentum=True,
90+
early_stopping=False,
91+
validation_fraction=0.1,
92+
beta_1=0.9,
93+
beta_2=0.999,
94+
epsilon=1e-08,
95+
n_iter_no_change=10,
96+
max_fun=15000,
97+
)
98+
self._check_params()
99+
100+
def fit(self, x: np.ndarray) -> None:
101+
"""
102+
Train the attack model.
103+
104+
:param x: Input to training process. Includes all features used to train the original model.
105+
"""
106+
107+
# Checks:
108+
if self.single_index_feature and self.attack_feature >= x.shape[1]:
109+
raise ValueError("attack_feature must be a valid index to a feature in x")
110+
111+
# get vector of attacked feature
112+
y = x[:, self.attack_feature]
113+
if self.single_index_feature:
114+
y_one_hot = float_to_categorical(y)
115+
else:
116+
y_one_hot = floats_to_one_hot(y)
117+
y_ready = check_and_transform_label_format(y_one_hot, len(np.unique(y)), return_one_hot=True)
118+
119+
# create training set for attack model
120+
x_train = np.delete(x, self.attack_feature, 1).astype(np.float32)
121+
122+
# train attack model
123+
self.attack_model.fit(x_train, y_ready)
124+
125+
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
126+
"""
127+
Infer the attacked feature.
128+
129+
:param x: Input to attack. Includes all features except the attacked feature.
130+
:param y: Not used in this attack.
131+
:param values: Possible values for attacked feature. Only needed in case of categorical feature (not one-hot).
132+
:type values: `np.ndarray`
133+
:return: The inferred feature values.
134+
"""
135+
x_test = x.astype(np.float32)
136+
137+
if self.single_index_feature:
138+
if "values" not in kwargs.keys():
139+
raise ValueError("Missing parameter `values`.")
140+
values: np.ndarray = kwargs.get("values")
141+
return np.array([values[np.argmax(arr)] for arr in self.attack_model.predict(x_test)])
142+
else:
143+
if "values" in kwargs.keys():
144+
values = kwargs.get("values")
145+
predictions = self.attack_model.predict(x_test).astype(np.float32)
146+
i = 0
147+
for column in predictions.T:
148+
for index in range(len(values[i])):
149+
np.place(column, [column == index], values[i][index])
150+
i += 1
151+
return np.array(predictions)
152+
else:
153+
return np.array(self.attack_model.predict(x_test))
154+
155+
def _check_params(self) -> None:
156+
if not isinstance(self.attack_feature, int) and not isinstance(self.attack_feature, slice):
157+
raise ValueError("Attack feature must be either an integer or a slice object.")
158+
if isinstance(self.attack_feature, int) and self.attack_feature < 0:
159+
raise ValueError("Attack feature index must be positive.")

notebooks/attack_attribute_inference.ipynb

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
},
3636
{
3737
"cell_type": "code",
38-
"execution_count": 3,
38+
"execution_count": 1,
3939
"metadata": {},
4040
"outputs": [],
4141
"source": [
@@ -57,14 +57,14 @@
5757
},
5858
{
5959
"cell_type": "code",
60-
"execution_count": 6,
60+
"execution_count": 2,
6161
"metadata": {},
6262
"outputs": [
6363
{
6464
"name": "stdout",
6565
"output_type": "stream",
6666
"text": [
67-
"Base model accuracy: 0.9489628557645924\n"
67+
"Base model accuracy: 0.9552339604438013\n"
6868
]
6969
}
7070
],
@@ -91,12 +91,12 @@
9191
},
9292
{
9393
"cell_type": "code",
94-
"execution_count": 8,
94+
"execution_count": 3,
9595
"metadata": {},
9696
"outputs": [],
9797
"source": [
9898
"import numpy as np\n",
99-
"from art.attacks.inference import AttributeInferenceBlackBox\n",
99+
"from art.attacks.inference.attribute_inference import AttributeInferenceBlackBox\n",
100100
"\n",
101101
"attack_feature = 1 # social\n",
102102
"\n",
@@ -123,14 +123,14 @@
123123
},
124124
{
125125
"cell_type": "code",
126-
"execution_count": 11,
126+
"execution_count": 4,
127127
"metadata": {},
128128
"outputs": [
129129
{
130130
"name": "stdout",
131131
"output_type": "stream",
132132
"text": [
133-
"0.7055191045928213\n"
133+
"0.6981860285604014\n"
134134
]
135135
}
136136
],
@@ -161,19 +161,19 @@
161161
},
162162
{
163163
"cell_type": "code",
164-
"execution_count": 13,
164+
"execution_count": 5,
165165
"metadata": {},
166166
"outputs": [
167167
{
168168
"name": "stdout",
169169
"output_type": "stream",
170170
"text": [
171-
"0.6526437668853724\n"
171+
"0.6522578155152451\n"
172172
]
173173
}
174174
],
175175
"source": [
176-
"from art.attacks.inference import AttributeInferenceWhiteBoxLifestyleDecisionTree\n",
176+
"from art.attacks.inference.attribute_inference import AttributeInferenceWhiteBoxLifestyleDecisionTree\n",
177177
"\n",
178178
"wb_attack = AttributeInferenceWhiteBoxLifestyleDecisionTree(art_classifier, attack_feature=attack_feature)\n",
179179
"\n",
@@ -196,19 +196,19 @@
196196
},
197197
{
198198
"cell_type": "code",
199-
"execution_count": 14,
199+
"execution_count": 6,
200200
"metadata": {},
201201
"outputs": [
202202
{
203203
"name": "stdout",
204204
"output_type": "stream",
205205
"text": [
206-
"0.7124662292551138\n"
206+
"0.713624083365496\n"
207207
]
208208
}
209209
],
210210
"source": [
211-
"from art.attacks.inference import AttributeInferenceWhiteBoxDecisionTree\n",
211+
"from art.attacks.inference.attribute_inference import AttributeInferenceWhiteBoxDecisionTree\n",
212212
"\n",
213213
"wb2_attack = AttributeInferenceWhiteBoxDecisionTree(art_classifier, attack_feature=attack_feature)\n",
214214
"\n",
@@ -231,16 +231,16 @@
231231
},
232232
{
233233
"cell_type": "code",
234-
"execution_count": 16,
234+
"execution_count": 7,
235235
"metadata": {},
236236
"outputs": [
237237
{
238238
"name": "stdout",
239239
"output_type": "stream",
240240
"text": [
241-
"(0.7638888888888888, 0.13110846245530394)\n",
242-
"(0.3849056603773585, 0.12157330154946365)\n",
243-
"(0.6666666666666666, 0.22407628128724671)\n"
241+
"(0.654054054054054, 0.14421930870083433)\n",
242+
"(0.3892857142857143, 0.1299165673420739)\n",
243+
"(0.6644067796610169, 0.23361144219308702)\n"
244244
]
245245
}
246246
],
@@ -276,6 +276,47 @@
276276
"# white-box 2\n",
277277
"print(calc_precision_recall(inferred_train_wb2, np.around(x_train_feature, decimals=8), positive_value=1.41404987))"
278278
]
279+
},
280+
{
281+
"cell_type": "markdown",
282+
"metadata": {},
283+
"source": [
284+
"To verify the significance of these results, we now run a baseline attack that uses only the remaining features to try to predict the value of the attacked feature, with no use of the model itself."
285+
]
286+
},
287+
{
288+
"cell_type": "code",
289+
"execution_count": 9,
290+
"metadata": {},
291+
"outputs": [
292+
{
293+
"name": "stdout",
294+
"output_type": "stream",
295+
"text": [
296+
"0.6761868004631416\n"
297+
]
298+
}
299+
],
300+
"source": [
301+
"from art.attacks.inference.attribute_inference import AttributeInferenceBaseline\n",
302+
"\n",
303+
"baseline_attack = AttributeInferenceBaseline(attack_feature=attack_feature)\n",
304+
"\n",
305+
"# train attack model\n",
306+
"baseline_attack.fit(x_test)\n",
307+
"# infer values\n",
308+
"inferred_train_baseline = baseline_attack.infer(x_train_for_attack, values=values)\n",
309+
"# check accuracy\n",
310+
"baseline_train_acc = np.sum(inferred_train_baseline == np.around(x_train_feature, decimals=8).reshape(1,-1)) / len(inferred_train_baseline)\n",
311+
"print(baseline_train_acc)"
312+
]
313+
},
314+
{
315+
"cell_type": "markdown",
316+
"metadata": {},
317+
"source": [
318+
"We can see that both the black-box attack and the second white-box attack do slightly better than the baseline."
319+
]
279320
}
280321
],
281322
"metadata": {
@@ -294,7 +335,7 @@
294335
"name": "python",
295336
"nbconvert_exporter": "python",
296337
"pygments_lexer": "ipython3",
297-
"version": "3.7.1"
338+
"version": "3.8.3"
298339
}
299340
},
300341
"nbformat": 4,

0 commit comments

Comments
 (0)