Skip to content

Commit bffd4b9

Browse files
authored
Merge pull request #1132 from abigailgold/dev_1.7.0_new_attr_attack
Attribute inference attack that uses membership inference attack
2 parents 9ea2c2b + b9e394e commit bffd4b9

File tree

15 files changed

+821
-56
lines changed

15 files changed

+821
-56
lines changed

art/attacks/attack.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,12 @@ def __init__(self, estimator):
329329
@abc.abstractmethod
330330
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
331331
"""
332-
Infer sensitive properties (attributes, membership training records) from the targeted estimator. This method
332+
Infer sensitive attributes from the targeted estimator. This method
333333
should be overridden by all concrete inference attack implementations.
334334
335335
:param x: An array with reference inputs to be used in the attack.
336336
:param y: Labels for `x`. This parameter is only used by some of the attacks.
337-
:return: An array holding the inferred properties.
337+
:return: An array holding the inferred attribute values.
338338
"""
339339
raise NotImplementedError
340340

@@ -358,12 +358,41 @@ def __init__(self, estimator, attack_feature: Union[int, slice] = 0):
358358
@abc.abstractmethod
359359
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
360360
"""
361-
Infer sensitive properties (attributes, membership training records) from the targeted estimator. This method
361+
Infer sensitive attributes from the targeted estimator. This method
362362
should be overridden by all concrete inference attack implementations.
363363
364364
:param x: An array with reference inputs to be used in the attack.
365365
:param y: Labels for `x`. This parameter is only used by some of the attacks.
366-
:return: An array holding the inferred properties.
366+
:return: An array holding the inferred attribute values.
367+
"""
368+
raise NotImplementedError
369+
370+
371+
class MembershipInferenceAttack(InferenceAttack):
    """
    Abstract base class for membership inference attack classes.
    """

    def __init__(self, estimator: "CLASSIFIER_TYPE"):
        """
        :param estimator: A trained estimator targeted for inference attack.
        :type estimator: :class:`.art.estimators.estimator.BaseEstimator`
        """
        super().__init__(estimator)

    @abc.abstractmethod
    def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
        """
        Infer membership status of samples from the target estimator. This method
        should be overridden by all concrete inference attack implementations.

        :param x: An array with reference inputs to be used in the attack.
        :param y: Labels for `x`. This parameter is only used by some of the attacks.
        :param probabilities: Keyword argument (passed through `kwargs`). A boolean indicating whether to return
                              the predicted probabilities per class, or just the predicted class.
        :type probabilities: bool
        :return: An array holding the inferred membership status (1 indicates member of training set,
                 0 indicates non-member) or class probabilities.
        """
        raise NotImplementedError
369398

art/attacks/inference/attribute_inference/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from art.attacks.inference.attribute_inference.white_box_lifestyle_decision_tree import (
88
AttributeInferenceWhiteBoxLifestyleDecisionTree,
99
)
10+
from art.attacks.inference.attribute_inference.meminf_based import AttributeInferenceMembership

art/attacks/inference/attribute_inference/baseline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,12 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
128128
129129
:param x: Input to attack. Includes all features except the attacked feature.
130130
:param y: Not used in this attack.
131-
:param values: Possible values for attacked feature. Only needed in case of categorical feature (not one-hot).
132-
:type values: `np.ndarray`
131+
:param values: Possible values for attacked feature. For a single column feature this should be a simple list
132+
containing all possible values, in increasing order (the smallest value in the 0 index and so
133+
on). For a multi-column feature (for example 1-hot encoded and then scaled), this should be a
134+
list of lists, where each internal list represents a column (in increasing order) and the values
135+
represent the possible values for that column (in increasing order).
136+
:type values: list
133137
:return: The inferred feature values.
134138
"""
135139
x_test = x.astype(np.float32)

art/attacks/inference/attribute_inference/black_box.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,12 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
138138
139139
:param x: Input to attack. Includes all features except the attacked feature.
140140
:param y: Original model's predictions for x.
141-
:param values: Possible values for attacked feature. Only needed in case of categorical feature (not one-hot).
142-
:type values: `np.ndarray`
141+
:param values: Possible values for attacked feature. For a single column feature this should be a simple list
142+
containing all possible values, in increasing order (the smallest value in the 0 index and so
143+
on). For a multi-column feature (for example 1-hot encoded and then scaled), this should be a
144+
list of lists, where each internal list represents a column (in increasing order) and the values
145+
represent the possible values for that column (in increasing order).
146+
:type values: list
143147
:return: The inferred feature values.
144148
"""
145149
if y is None:
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements attribute inference attacks using membership inference attacks.
20+
"""
21+
from __future__ import absolute_import, division, print_function, unicode_literals
22+
23+
import logging
24+
from typing import Optional, Union, List, TYPE_CHECKING
25+
26+
import numpy as np
27+
28+
from art.estimators.estimator import BaseEstimator
29+
from art.estimators.classification.classifier import ClassifierMixin
30+
from art.attacks.attack import AttributeInferenceAttack, MembershipInferenceAttack
31+
from art.exceptions import EstimatorError
32+
33+
if TYPE_CHECKING:
34+
from art.utils import CLASSIFIER_TYPE
35+
36+
logger = logging.getLogger(__name__)
37+
38+
39+
class AttributeInferenceMembership(AttributeInferenceAttack):
    """
    Implementation of an attribute inference attack that utilizes a membership inference attack.

    The idea is to find the target feature value that causes the membership inference attack to classify the sample
    as a member with the highest confidence.
    """

    _estimator_requirements = (BaseEstimator, ClassifierMixin)

    def __init__(
        self,
        classifier: "CLASSIFIER_TYPE",
        membership_attack: MembershipInferenceAttack,
        attack_feature: Union[int, slice] = 0,
    ):
        """
        Create an AttributeInferenceMembership attack instance.

        :param classifier: Target classifier.
        :param membership_attack: The membership inference attack to use. Should be fit/calibrated in advance, and
                                  should support returning probabilities.
        :param attack_feature: The index of the feature to be attacked or a slice representing multiple indexes in
                               case of a one-hot encoded feature.
        :raises ValueError: If `membership_attack` is not a `MembershipInferenceAttack` or `attack_feature` is
                            invalid.
        :raises EstimatorError: If `classifier` does not satisfy the estimator requirements of `membership_attack`.
        """
        super().__init__(estimator=classifier, attack_feature=attack_feature)

        # Validate the type before touching `estimator_requirements`; otherwise an arbitrary object would fail
        # with an unrelated AttributeError instead of the intended ValueError.
        if not isinstance(membership_attack, MembershipInferenceAttack):
            raise ValueError("membership_attack should be a sub-class of MembershipInferenceAttack")

        if not all(t in type(classifier).__mro__ for t in membership_attack.estimator_requirements):
            # Pass the attack's class rather than the instance, matching the `this_class` argument of
            # EstimatorError (see art.exceptions).
            raise EstimatorError(type(membership_attack), membership_attack.estimator_requirements, classifier)

        self.membership_attack = membership_attack
        self._check_params()

    def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
        """
        Infer the attacked feature.

        Every candidate value is inserted into `x` at the position of the attacked feature, the membership attack
        is queried for the probability of the completed sample being a member, and the candidate with the highest
        member probability is returned for each row.

        :param x: Input to attack. Includes all features except the attacked feature.
        :param y: The labels expected by the membership attack.
        :param values: Possible values for attacked feature. For a single column feature this should be a simple list
                       containing all possible values, in increasing order (the smallest value in the 0 index and so
                       on). For a multi-column feature (for example 1-hot encoded and then scaled), this should be a
                       list of lists, where each internal list represents a column (in increasing order) and the values
                       represent the possible values for that column (in increasing order).
        :type values: list
        :return: The inferred feature values.
        :raises ValueError: If `values` is missing or empty, or if the shapes of `x`, `y` and the classifier input
                            do not match.
        """
        if self.estimator.input_shape is not None:
            if isinstance(self.attack_feature, int) and self.estimator.input_shape[0] != x.shape[1] + 1:
                raise ValueError("Number of features in x + 1 does not match input_shape of classifier")

        if "values" not in kwargs:
            raise ValueError("Missing parameter `values`.")
        values: Optional[List] = kwargs.get("values")
        if not values:
            raise ValueError("`values` cannot be None or empty")

        if y is not None and y.shape[0] != x.shape[0]:
            raise ValueError("Number of rows in x and y do not match")

        n_rows = x.shape[0]

        if isinstance(self.attack_feature, int):
            # Single column feature: each candidate is one constant column.
            candidates = (np.full((n_rows, 1), value).astype(np.float32) for value in values)
            probabilities = self._member_probabilities(x, y, candidates, self.attack_feature)
            best = np.argmax(probabilities, axis=1)
            # Map the winning candidate index of every row back to the actual feature value.
            return np.asarray(values, dtype=np.float32)[best]

        # 1-hot encoded feature. Can also be scaled.
        # Assumes that the second value is the "positive" value and that there can only be one positive column.
        encodings = self._one_hot_encodings(values)
        candidates = (np.tile(encoding, (n_rows, 1)) for encoding in encodings)
        # A slice without an explicit start begins at column 0.
        insert_at = self.attack_feature.start or 0
        probabilities = self._member_probabilities(x, y, candidates, insert_at)
        best = np.argmax(probabilities, axis=1)
        return encodings[best].astype(probabilities.dtype)

    def _member_probabilities(self, x: np.ndarray, y: Optional[np.ndarray], candidates, insert_at: int) -> np.ndarray:
        """
        Query the membership attack once per candidate and collect the "member" probabilities.

        :param x: Input without the attacked feature.
        :param y: The labels expected by the membership attack.
        :param candidates: Iterable of arrays with one row per sample, each holding one candidate value (or encoding)
                           of the attacked feature.
        :param insert_at: Column index of `x` before which the candidate columns are inserted.
        :return: Array of shape (number of samples, number of candidates) with the probability of being a member.
        """
        columns = []
        for candidate in candidates:
            x_value = np.concatenate((x[:, :insert_at], candidate, x[:, insert_at:]), axis=1)
            predicted = self.membership_attack.infer(x_value, y, probabilities=True)
            # Column 1 holds the probability of the "member" class.
            columns.append(predicted[:, 1].reshape(-1, 1))
        # Stack once at the end instead of re-copying the growing matrix on every iteration.
        return np.hstack(columns)

    @staticmethod
    def _one_hot_encodings(values: List) -> np.ndarray:
        """
        Build all valid encodings of a (possibly scaled) 1-hot encoded feature.

        :param values: List of lists; `values[j][0]` is the "negative" and `values[j][1]` the "positive" value of
                       column `j`.
        :return: Square array in which row `i` has column `i` set to its positive value and every other column set
                 to its negative value.
        """
        size = len(values)
        encodings = np.tile(np.array([column[0] for column in values], dtype=np.float64), (size, 1))
        diagonal = np.arange(size)
        encodings[diagonal, diagonal] = [column[1] for column in values]
        return encodings

    def _check_params(self) -> None:
        """
        Validate the attack parameters.

        :raises ValueError: If `attack_feature` is neither a non-negative integer nor a slice, or if
                            `membership_attack` is not a `MembershipInferenceAttack`.
        """
        if not isinstance(self.attack_feature, (int, slice)):
            raise ValueError("Attack feature must be either an integer or a slice object.")
        if isinstance(self.attack_feature, int) and self.attack_feature < 0:
            raise ValueError("Attack feature index must be positive.")
        if not isinstance(self.membership_attack, MembershipInferenceAttack):
            raise ValueError("membership_attack should be a sub-class of MembershipInferenceAttack")

art/attacks/inference/attribute_inference/white_box_decision_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
6767
:param x: Input to attack. Includes all features except the attacked feature.
6868
:param y: Original model's predictions for x.
6969
:param values: Possible values for attacked feature.
70-
:type values: `np.ndarray`
70+
:type values: list
7171
:param priors: Prior distributions of attacked feature values. Same size array as `values`.
72-
:type priors: `np.ndarray`
72+
:type priors: list
7373
:return: The inferred feature values.
7474
"""
7575
if "priors" not in kwargs.keys():

art/attacks/inference/attribute_inference/white_box_lifestyle_decision_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
6464
:param x: Input to attack. Includes all features except the attacked feature.
6565
:param y: Not used.
6666
:param values: Possible values for attacked feature.
67-
:type values: `np.ndarray`
67+
:type values: list
6868
:param priors: Prior distributions of attacked feature values. Same size array as `values`.
69-
:type priors: `np.ndarray`
69+
:type priors: list
7070
:return: The inferred feature values.
7171
:rtype: `np.ndarray`
7272
"""

art/attacks/inference/membership_inference/black_box.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
3030

31-
from art.attacks.attack import InferenceAttack
31+
from art.attacks.attack import MembershipInferenceAttack
3232
from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin
3333
from art.estimators.classification.classifier import ClassifierMixin
3434
from art.utils import check_and_transform_label_format
@@ -39,15 +39,15 @@
3939
logger = logging.getLogger(__name__)
4040

4141

42-
class MembershipInferenceBlackBox(InferenceAttack):
42+
class MembershipInferenceBlackBox(MembershipInferenceAttack):
4343
"""
4444
Implementation of a learned black-box membership inference attack.
4545
4646
This implementation can use as input to the learning process probabilities/logits or losses,
4747
depending on the type of model and provided configuration.
4848
"""
4949

50-
attack_params = InferenceAttack.attack_params + [
50+
attack_params = MembershipInferenceAttack.attack_params + [
5151
"input_type",
5252
"attack_model_type",
5353
"attack_model",
@@ -231,10 +231,7 @@ def fit( # pylint: disable=W0613
231231
loss.backward()
232232
optimizer.step()
233233
else:
234-
if self.attack_model_type == "gb":
235-
y_ready = check_and_transform_label_format(y_new, len(np.unique(y_new)), return_one_hot=False)
236-
else:
237-
y_ready = check_and_transform_label_format(y_new, len(np.unique(y_new)), return_one_hot=True)
234+
y_ready = check_and_transform_label_format(y_new, len(np.unique(y_new)), return_one_hot=False)
238235
self.attack_model.fit(np.c_[x_1, x_2], y_ready) # type: ignore
239236

240237
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
@@ -243,7 +240,10 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
243240
244241
:param x: Input records to attack.
245242
:param y: True labels for `x`.
246-
:return: An array holding the inferred membership status, 1 indicates a member and 0 indicates non-member.
243+
:param probabilities: a boolean indicating whether to return the predicted probabilities per class, or just
244+
the predicted class
245+
:return: An array holding the inferred membership status, 1 indicates a member and 0 indicates non-member,
246+
or class probabilities.
247247
"""
248248
if y is None:
249249
raise ValueError("MembershipInferenceBlackBox requires true labels `y`.")
@@ -252,6 +252,11 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
252252
if self.estimator.input_shape[0] != x.shape[1]:
253253
raise ValueError("Shape of x does not match input_shape of classifier")
254254

255+
if "probabilities" in kwargs.keys():
256+
probabilities = kwargs.get("probabilities")
257+
else:
258+
probabilities = False
259+
255260
y = check_and_transform_label_format(y, len(np.unique(y)), return_one_hot=True)
256261

257262
if y.shape[0] != x.shape[0]:
@@ -274,7 +279,10 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
274279
for input1, input2, _ in test_loader:
275280
input1, input2 = to_cuda(input1), to_cuda(input2)
276281
outputs = self.attack_model(input1, input2) # type: ignore
277-
predicted = torch.round(outputs)
282+
if not probabilities:
283+
predicted = torch.round(outputs)
284+
else:
285+
predicted = outputs
278286
predicted = from_cuda(predicted)
279287

280288
if inferred is None:
@@ -283,12 +291,27 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
283291
inferred = np.vstack((inferred, predicted.detach().numpy()))
284292

285293
if inferred is not None:
286-
inferred_return = inferred.reshape(-1).astype(np.int)
294+
if not probabilities:
295+
inferred_return = inferred.reshape(-1).astype(np.int)
296+
else:
297+
inferred = inferred.reshape(-1)
298+
prob_0 = np.ones_like(inferred) - inferred
299+
inferred_return = np.stack((prob_0, inferred), axis=1)
287300
else:
288301
raise ValueError("No data available.")
289-
else:
302+
elif not self.default_model:
303+
# assumes the predict method of the supplied model returns probabilities
290304
pred = self.attack_model.predict(np.c_[features, y]) # type: ignore
291-
inferred_return = np.array([np.argmax(arr) for arr in pred])
305+
if probabilities:
306+
inferred_return = pred
307+
else:
308+
inferred_return = np.array([np.argmax(arr) for arr in pred])
309+
else:
310+
pred = self.attack_model.predict_proba(np.c_[features, y]) # type: ignore
311+
if probabilities:
312+
inferred_return = pred
313+
else:
314+
inferred_return = np.array([np.argmax(arr) for arr in pred])
292315

293316
return inferred_return
294317

0 commit comments

Comments
 (0)