Skip to content

Commit f31d2ab

Browse files
author
TrojAISec
committed
applying binary classification patch supplied by Beat
1 parent 6e38ecf commit f31d2ab

File tree

3 files changed

+10
-15
lines changed

3 files changed

+10
-15
lines changed

art/attacks/inference/membership_inference/black_box.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,9 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
292292

293293
if inferred is not None:
294294
if not probabilities:
295-
inferred_return = inferred.reshape(-1).astype(np.int)
295+
inferred_return = np.round(inferred)
296296
else:
297-
inferred = inferred.reshape(-1)
298-
prob_0 = np.ones_like(inferred) - inferred
299-
inferred_return = np.stack((prob_0, inferred), axis=1)
297+
inferred_return = inferred
300298
else:
301299
raise ValueError("No data available.")
302300
elif not self.default_model:
@@ -305,13 +303,13 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
305303
if probabilities:
306304
inferred_return = pred
307305
else:
308-
inferred_return = np.array([np.argmax(arr) for arr in pred])
306+
inferred_return = np.round(pred)
309307
else:
310308
pred = self.attack_model.predict_proba(np.c_[features, y]) # type: ignore
311309
if probabilities:
312-
inferred_return = pred
310+
inferred_return = pred[:, [1]]
313311
else:
314-
inferred_return = np.array([np.argmax(arr) for arr in pred])
312+
inferred_return = np.round(pred[:, [1]])
315313

316314
return inferred_return
317315

tests/attacks/inference/membership_inference/test_black_box.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def test_black_box_with_model(art_warning, tabular_dl_estimator_for_attack, esti
122122
try:
123123
classifier = tabular_dl_estimator_for_attack(MembershipInferenceBlackBox)
124124
attack_model = estimator_for_attack(num_features=2 * num_classes_iris)
125-
print(type(attack_model).__name__)
126125
attack = MembershipInferenceBlackBox(classifier, attack_model=attack_model)
127126
backend_check_membership_accuracy(attack, get_iris_dataset, attack_train_ratio, 0.25)
128127
except ARTTestException as e:
@@ -153,7 +152,6 @@ def test_black_box_with_model_prob(
153152
try:
154153
classifier = tabular_dl_estimator_for_attack(MembershipInferenceBlackBox)
155154
attack_model = estimator_for_attack(num_features=2 * num_classes_iris)
156-
print(type(attack_model).__name__)
157155
attack = MembershipInferenceBlackBox(classifier, attack_model=attack_model)
158156
backend_check_membership_probabilities(attack, get_iris_dataset, attack_train_ratio)
159157
except ARTTestException as e:
@@ -230,6 +228,5 @@ def backend_check_membership_probabilities(attack, dataset, attack_train_ratio):
230228

231229

232230
def backend_check_probabilities(pred, prob):
233-
assert prob.shape[1] == 2
234-
assert np.all(np.around(np.sum(prob, axis=1), decimals=5) == 1)
235-
assert np.all(np.argmax(prob, axis=1) == pred.astype(int))
231+
assert prob.shape[1] == 1
232+
assert np.all(np.round(prob) == pred.astype(int))

tests/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,8 +1631,8 @@ def get_attack_classifier_pt(num_features):
16311631
class AttackModel(nn.Module):
16321632
def __init__(self, num_features):
16331633
super(AttackModel, self).__init__()
1634-
self.layer = nn.Linear(num_features, 2)
1635-
self.output = nn.Softmax(dim=1)
1634+
self.layer = nn.Linear(num_features, 1)
1635+
self.output = nn.Sigmoid()
16361636

16371637
def forward(self, x):
16381638
return self.output(self.layer(x))
@@ -1641,7 +1641,7 @@ def forward(self, x):
16411641
model = AttackModel(num_features)
16421642

16431643
# Define a loss function and optimizer
1644-
loss_fn = nn.CrossEntropyLoss()
1644+
loss_fn = nn.BCELoss()
16451645
optimizer = optim.Adam(model.parameters(), lr=0.0001)
16461646
attack_model = PyTorchClassifier(
16471647
model=model, loss=loss_fn, optimizer=optimizer, input_shape=(num_features,), nb_classes=2

0 commit comments

Comments
 (0)