Skip to content

Commit 820ddaf

Browse files
author
TrojAISec
committed
another patch from Beat to fix meminf tests
1 parent f31d2ab commit 820ddaf

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

art/attacks/inference/attribute_inference/meminf_based.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
Create an AttributeInferenceMembership attack instance.
5757
5858
:param classifier: Target classifier.
59-
:param membership_attack: The membership inference attack to use. Should be fit/callibrated in advance, and
59+
:param membership_attack: The membership inference attack to use. Should be fit/calibrated in advance, and
6060
should support returning probabilities.
6161
:param attack_feature: The index of the feature to be attacked or a slice representing multiple indexes in
6262
case of a one-hot encoded feature.
@@ -106,10 +106,10 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
106106

107107
predicted = self.membership_attack.infer(x_value, y, probabilities=True)
108108
if first:
109-
probabilities = predicted[:, 1].reshape(-1, 1)
109+
probabilities = predicted
110110
first = False
111111
else:
112-
probabilities = np.hstack((probabilities, predicted[:, 1].reshape(-1, 1)))
112+
probabilities = np.hstack((probabilities, predicted))
113113

114114
# needs to be of type float so we can later replace back the actual values
115115
value_indexes = np.argmax(probabilities, axis=1).astype(np.float32)
@@ -130,9 +130,9 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
130130

131131
predicted = self.membership_attack.infer(x_value, y, probabilities=True)
132132
if first:
133-
probabilities = predicted[:, 1].reshape(-1, 1)
133+
probabilities = predicted
134134
else:
135-
probabilities = np.hstack((probabilities, predicted[:, 1].reshape(-1, 1)))
135+
probabilities = np.hstack((probabilities, predicted))
136136
first = False
137137
value_indexes = np.argmax(probabilities, axis=1).astype(np.float32)
138138
pred_values = np.zeros_like(probabilities)

tests/attacks/inference/attribute_inference/test_meminf_based.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def transform_feature(x):
187187
inferred_train = attack.infer(x_train_for_attack, y_train_iris, values=values)
188188
inferred_test = attack.infer(x_test_for_attack, y_test_iris, values=values)
189189
# check accuracy
190-
train_acc = np.sum(inferred_train == x_train_feature.reshape(1, -1)) / len(inferred_train)
191-
test_acc = np.sum(inferred_test == x_test_feature.reshape(1, -1)) / len(inferred_test)
190+
train_acc = np.sum(inferred_train == x_train_feature) / len(inferred_train)
191+
test_acc = np.sum(inferred_test == x_test_feature) / len(inferred_test)
192192
assert 0.1 <= train_acc
193193
assert 0.1 <= test_acc
194194

@@ -325,18 +325,18 @@ def transform_feature(x):
325325
attack_train_ratio = 0.5
326326
attack_train_size = int(len(x_train) * attack_train_ratio)
327327
attack_test_size = int(len(x_test) * attack_train_ratio)
328-
# attack without callibration
328+
# attack without calibration
329329
attack = AttributeInferenceMembership(classifier, meminf_attack, attack_feature=attack_feature)
330330
# infer attacked feature
331331
inferred_train = attack.infer(x_train_for_attack, y_train_iris, values=values)
332332
inferred_test = attack.infer(x_test_for_attack, y_test_iris, values=values)
333333
# check accuracy
334-
train_acc = np.sum(inferred_train == x_train_feature.reshape(1, -1)) / len(inferred_train)
335-
test_acc = np.sum(inferred_test == x_test_feature.reshape(1, -1)) / len(inferred_test)
334+
train_acc = np.sum(inferred_train == x_train_feature) / len(inferred_train)
335+
test_acc = np.sum(inferred_test == x_test_feature) / len(inferred_test)
336336
assert 0.5 <= train_acc
337337
assert 0.5 <= test_acc
338338

339-
# attack with callibration
339+
# attack with calibration
340340
meminf_attack.calibrate_distance_threshold(
341341
x_train[:attack_train_size],
342342
y_train_iris[:attack_train_size],
@@ -349,8 +349,8 @@ def transform_feature(x):
349349
inferred_train = attack.infer(x_train_for_attack, y_train_iris, values=values)
350350
inferred_test = attack.infer(x_test_for_attack, y_test_iris, values=values)
351351
# check accuracy
352-
train_acc = np.sum(inferred_train == x_train_feature.reshape(1, -1)) / len(inferred_train)
353-
test_acc = np.sum(inferred_test == x_test_feature.reshape(1, -1)) / len(inferred_test)
352+
train_acc = np.sum(inferred_train == x_train_feature) / len(inferred_train)
353+
test_acc = np.sum(inferred_test == x_test_feature) / len(inferred_test)
354354
assert 0.1 <= train_acc
355355
assert 0.1 <= test_acc
356356

0 commit comments

Comments
 (0)