Skip to content

Commit 1792db8

Browse files
authored
Merge pull request #768 from abigailgold/dev_1.5.0_OneHotAttribute
One-hot attribute inference
2 parents cba5b32 + 2126f95 commit 1792db8

File tree

5 files changed

+129
-18
lines changed

5 files changed

+129
-18
lines changed

art/attacks/attack.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ class AttributeInferenceAttack(InferenceAttack):
317317

318318
attack_params = InferenceAttack.attack_params + ["attack_feature"]
319319

320-
def __init__(self, estimator, attack_feature: int = 0):
320+
def __init__(self, estimator, attack_feature: Union[int, slice] = 0):
321321
"""
322322
:param estimator: A trained estimator targeted for inference attack.
323323
:type estimator: :class:`.art.estimators.estimator.BaseEstimator`
@@ -346,10 +346,6 @@ def set_params(self, **kwargs) -> None:
346346
super().set_params(**kwargs)
347347
self._check_params()
348348

349-
def _check_params(self) -> None:
350-
if self.attack_feature < 0:
351-
raise ValueError("Attack feature must be positive.")
352-
353349

354350
class ReconstructionAttack(Attack):
355351
"""

art/attacks/inference/attribute_inference/black_box.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from __future__ import absolute_import, division, print_function, unicode_literals
2323

2424
import logging
25-
from typing import Optional, TYPE_CHECKING
25+
from typing import Optional, Union, TYPE_CHECKING
2626

2727
import numpy as np
2828
from sklearn.neural_network import MLPClassifier
@@ -51,16 +51,24 @@ class AttributeInferenceBlackBox(AttributeInferenceAttack):
5151
_estimator_requirements = (BaseEstimator, ClassifierMixin)
5252

5353
def __init__(
54-
self, classifier: "CLASSIFIER_TYPE", attack_model: Optional["CLASSIFIER_TYPE"] = None, attack_feature: int = 0
54+
self,
55+
classifier: "CLASSIFIER_TYPE",
56+
attack_model: Optional["CLASSIFIER_TYPE"] = None,
57+
attack_feature: Union[int, slice] = 0,
5558
):
5659
"""
5760
Create an AttributeInferenceBlackBox attack instance.
5861
5962
:param classifier: Target classifier.
6063
:param attack_model: The attack model to train, optional. If none is provided, a default model will be created.
61-
:param attack_feature: The index of the feature to be attacked.
64+
:param attack_feature: The index of the feature to be attacked or a slice representing multiple indexes in
65+
case of a one-hot encoded feature.
6266
"""
6367
super().__init__(estimator=classifier, attack_feature=attack_feature)
68+
if isinstance(self.attack_feature, int):
69+
self.single_index_feature = True
70+
else:
71+
self.single_index_feature = False
6472

6573
if attack_model:
6674
if ClassifierMixin not in type(attack_model).__mro__:
@@ -104,16 +112,18 @@ def fit(self, x: np.ndarray) -> None:
104112
# Checks:
105113
if self.estimator.input_shape[0] != x.shape[1]:
106114
raise ValueError("Shape of x does not match input_shape of classifier")
107-
if self.attack_feature >= x.shape[1]:
115+
if self.single_index_feature and self.attack_feature >= x.shape[1]:
108116
raise ValueError("attack_feature must be a valid index to a feature in x")
109117

110118
# get model's predictions for x
111119
predictions = np.array([np.argmax(arr) for arr in self.estimator.predict(x)]).reshape(-1, 1)
112120

113121
# get vector of attacked feature
114122
y = x[:, self.attack_feature]
115-
y_one_hot = float_to_categorical(y)
116-
y_ready = check_and_transform_label_format(y_one_hot, len(np.unique(y)), return_one_hot=True)
123+
y_ready = y
124+
if self.single_index_feature:
125+
y_one_hot = float_to_categorical(y)
126+
y_ready = check_and_transform_label_format(y_one_hot, len(np.unique(y)), return_one_hot=True)
117127

118128
# create training set for attack model
119129
x_train = np.concatenate((np.delete(x, self.attack_feature, 1), predictions), axis=1).astype(np.float32)
@@ -127,18 +137,27 @@ def infer(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
127137
128138
:param x: Input to attack. Includes all features except the attacked feature.
129139
:param y: Original model's predictions for x.
130-
:param values: Possible values for attacked feature.
140+
:param values: Possible values for attacked feature. Only needed in case of categorical feature (not one-hot).
131141
:type values: `np.ndarray`
132142
:return: The inferred feature values.
133143
"""
134144
if y.shape[0] != x.shape[0]:
135145
raise ValueError("Number of rows in x and y do not match")
136-
if self.estimator.input_shape[0] != x.shape[1] + 1:
146+
if self.single_index_feature and self.estimator.input_shape[0] != x.shape[1] + 1:
137147
raise ValueError("Number of features in x + 1 does not match input_shape of classifier")
138148

139-
if "values" not in kwargs.keys():
140-
raise ValueError("Missing parameter `values`.")
141-
values: np.ndarray = kwargs.get("values")
142-
143149
x_test = np.concatenate((x, y), axis=1).astype(np.float32)
144-
return np.array([values[np.argmax(arr)] for arr in self.attack_model.predict(x_test)])
150+
151+
if self.single_index_feature:
152+
if "values" not in kwargs.keys():
153+
raise ValueError("Missing parameter `values`.")
154+
values: np.ndarray = kwargs.get("values")
155+
return np.array([values[np.argmax(arr)] for arr in self.attack_model.predict(x_test)])
156+
else:
157+
return np.array(self.attack_model.predict(x_test))
158+
159+
def _check_params(self) -> None:
    """
    Validate the `attack_feature` parameter of the black-box attack.

    `attack_feature` is either a single column index (a non-negative `int`) or a `slice` covering the columns of
    a one-hot encoded feature.

    :raises ValueError: If `attack_feature` is neither an `int` nor a `slice`, or if it is a negative `int`.
    """
    if not isinstance(self.attack_feature, (int, slice)):
        raise ValueError("Attack feature must be either an integer or a slice object.")
    # Zero is a valid column index, so only strictly negative indexes are rejected; the message says
    # "non-negative" rather than "positive" to match what is actually accepted.
    if isinstance(self.attack_feature, int) and self.attack_feature < 0:
        raise ValueError("Attack feature index must be non-negative.")

art/attacks/inference/attribute_inference/white_box_decision_tree.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self, classifier: ScikitlearnDecisionTreeClassifier, attack_feature
5555
:param attack_feature: The index of the feature to be attacked.
5656
"""
5757
super().__init__(estimator=classifier, attack_feature=attack_feature)
58+
self._check_params()
5859

5960
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
6061
"""
@@ -138,3 +139,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
138139
for index, value in enumerate(predicted_pred)
139140
]
140141
)
142+
143+
def _check_params(self) -> None:
    """
    Validate the `attack_feature` parameter of the attack.

    `attack_feature` has to be a non-negative integer column index.

    :raises ValueError: If `attack_feature` is not an integer or is negative.
    """
    # Check the type before comparing: `slice(...) < 0` (or any other non-numeric value) would otherwise surface
    # as a TypeError instead of the ValueError used for invalid parameters.
    if not isinstance(self.attack_feature, (int, np.integer)):
        raise ValueError("Attack feature must be an integer index.")
    # Zero is a valid column index, so only strictly negative values are rejected.
    if self.attack_feature < 0:
        raise ValueError("Attack feature must be non-negative.")

art/attacks/inference/attribute_inference/white_box_lifestyle_decision_tree.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self, classifier: "CLASSIFIER_TYPE", attack_feature: int = 0):
5555
:param attack_feature: The index of the feature to be attacked.
5656
"""
5757
super().__init__(estimator=classifier, attack_feature=attack_feature)
58+
self._check_params()
5859

5960
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
6061
"""
@@ -130,3 +131,7 @@ def _calculate_phi(self, x, values, n_samples):
130131
phi.append(num_value)
131132

132133
return phi
134+
135+
def _check_params(self) -> None:
    """
    Validate the `attack_feature` parameter of the attack.

    `attack_feature` has to be a non-negative integer column index.

    :raises ValueError: If `attack_feature` is not an integer or is negative.
    """
    # Check the type before comparing: `slice(...) < 0` (or any other non-numeric value) would otherwise surface
    # as a TypeError instead of the ValueError used for invalid parameters.
    if not isinstance(self.attack_feature, (int, np.integer)):
        raise ValueError("Attack feature must be an integer index.")
    # Zero is a valid column index, so only strictly negative values are rejected.
    if self.attack_feature < 0:
        raise ValueError("Attack feature must be non-negative.")

tests/attacks/inference/attribute_inference/test_black_box.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
import numpy as np
2424
import torch.nn as nn
2525
import torch.optim as optim
26+
from sklearn.tree import DecisionTreeClassifier
2627

2728
from art.attacks.inference.attribute_inference.black_box import AttributeInferenceBlackBox
2829
from art.estimators.classification.pytorch import PyTorchClassifier
2930
from art.estimators.estimator import BaseEstimator
3031
from art.estimators.classification import ClassifierMixin
32+
from art.estimators.classification.scikitlearn import ScikitlearnDecisionTreeClassifier
3133

3234
from tests.attacks.utils import backend_test_classifier_type_check_fail
3335
from tests.utils import ARTTestException
@@ -143,5 +145,89 @@ def transform_feature(x):
143145
art_warning(e)
144146

145147

148+
@pytest.mark.skipMlFramework("dl_frameworks")
def test_black_box_one_hot(art_warning, get_iris_dataset):
    """
    Check that the black-box attack can infer a one-hot encoded feature addressed through a slice.

    The petal-length column is bucketed into three categories, replaced in place by its one-hot encoding, and the
    attack is created with a slice spanning the encoded columns.
    """
    try:
        attack_feature = 2  # petal length
        # Width of the one-hot encoding. It is fixed (rather than derived from each split's maximum) so that it
        # always equals the width of the attacked slice, even if a split does not contain every category.
        n_categories = 3

        # need to transform attacked feature into categorical
        def transform_feature(x):
            x[x > 0.5] = 2
            x[(x > 0.2) & (x <= 0.5)] = 1
            x[x <= 0.2] = 0

        def encode_attacked_feature(x):
            """Return (`x` without the attacked column, its one-hot encoding, `x` with the encoding spliced in)."""
            # data without attacked feature
            x_for_attack = np.delete(x, attack_feature, 1)
            # only attacked feature
            feature = x[:, attack_feature].copy().reshape(-1, 1)
            transform_feature(feature)
            # transform to one-hot encoding
            one_hot = np.zeros((feature.size, n_categories))
            one_hot[np.arange(feature.size), feature.reshape(-1).astype(int)] = 1
            # data with attacked feature (after transformation), inserted at the original column position
            x_encoded = np.concatenate(
                (x_for_attack[:, :attack_feature], one_hot, x_for_attack[:, attack_feature:]), axis=1
            )
            return x_for_attack, one_hot, x_encoded

        (x_train_iris, y_train_iris), (x_test_iris, _) = get_iris_dataset
        x_train_for_attack, train_one_hot, x_train = encode_attacked_feature(x_train_iris)
        x_test_for_attack, test_one_hot, x_test = encode_attacked_feature(x_test_iris)

        y_train = np.array([np.argmax(y) for y in y_train_iris]).reshape(-1, 1)

        tree = DecisionTreeClassifier()
        tree.fit(x_train, y_train)
        classifier = ScikitlearnDecisionTreeClassifier(tree)

        attack = AttributeInferenceBlackBox(
            classifier, attack_feature=slice(attack_feature, attack_feature + n_categories)
        )
        # get original model's predictions
        x_train_predictions = np.array([np.argmax(arr) for arr in classifier.predict(x_train)]).reshape(-1, 1)
        x_test_predictions = np.array([np.argmax(arr) for arr in classifier.predict(x_test)]).reshape(-1, 1)
        # train attack model
        attack.fit(x_train)
        # infer attacked feature
        inferred_train = attack.infer(x_train_for_attack, x_train_predictions)
        inferred_test = attack.infer(x_test_for_attack, x_test_predictions)
        # check accuracy: a row counts as correct only if the whole one-hot vector matches
        train_acc = np.sum(np.all(inferred_train == train_one_hot, axis=1)) / len(inferred_train)
        test_acc = np.sum(np.all(inferred_test == test_one_hot, axis=1)) / len(inferred_test)
        assert pytest.approx(0.9145, abs=0.03) == train_acc
        assert pytest.approx(0.9333, abs=0.03) == test_acc

    except ARTTestException as e:
        art_warning(e)
208+
209+
def test_errors(art_warning, tabular_dl_estimator_for_attack, get_iris_dataset):
    """
    Check that invalid constructor arguments and mismatched inputs to `fit`/`infer` raise `ValueError`.
    """
    try:
        classifier = tabular_dl_estimator_for_attack(AttributeInferenceBlackBox)
        (x_train, y_train), (x_test, y_test) = get_iris_dataset

        # invalid `attack_feature` values are rejected at construction time
        for invalid_feature in ("a", -3):
            with pytest.raises(ValueError):
                AttributeInferenceBlackBox(classifier, attack_feature=invalid_feature)

        # `attack_feature=8` is accepted by the constructor; the error is only expected from `fit`
        out_of_range_attack = AttributeInferenceBlackBox(classifier, attack_feature=8)
        with pytest.raises(ValueError):
            out_of_range_attack.fit(x_train)

        default_attack = AttributeInferenceBlackBox(classifier)
        # training data with one column removed
        x_train_missing_column = np.delete(x_train, 1, 1)
        with pytest.raises(ValueError):
            default_attack.fit(x_train_missing_column)
        # `x` and `y` taken from different data splits
        with pytest.raises(ValueError):
            default_attack.infer(x_train, y_test)
        # `x` still contains every feature column, including the attacked one
        with pytest.raises(ValueError):
            default_attack.infer(x_train, y_train)
    except ARTTestException as e:
        art_warning(e)
230+
231+
146232
def test_classifier_type_check_fail():
    """
    Run the shared classifier type check for `AttributeInferenceBlackBox` with `BaseEstimator` and
    `ClassifierMixin` as the types passed to the check.
    """
    checked_types = (BaseEstimator, ClassifierMixin)
    backend_test_classifier_type_check_fail(AttributeInferenceBlackBox, checked_types)

0 commit comments

Comments
 (0)