Skip to content

Commit 017c711

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Add utility function for least likely class targets
1 parent d16110b commit 017c711

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

art/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,21 @@ def random_targets(labels, nb_classes):
124124
return to_categorical(result, nb_classes)
125125

126126

127+
def least_likely_class(x, classifier):
128+
"""
129+
Compute the least likely class predictions for sample `x`. This strategy for choosing attack targets was used in
130+
(Kurakin et al., 2016). See https://arxiv.org/abs/1607.02533.
131+
132+
:param x: A data sample of shape accepted by `classifier`.
133+
:type x: `np.ndarray`
134+
:param classifier: The classifier used for computing predictions.
135+
:type classifier: `Classifier`
136+
:return: Least-likely class predicted by `classifier` for sample `x` in one-hot encoding.
137+
:rtype: `np.ndarray`
138+
"""
139+
return to_categorical(np.argmin(classifier.predict(x), axis=1), nb_classes=classifier.nb_classes)
140+
141+
127142
def get_label_conf(y_vec):
128143
"""
129144
Returns the confidence and the label of the most probable class given a vector of class confidences

art/utils_unittest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from art.utils import load_mnist, projection, random_sphere, to_categorical
8+
from art.utils import load_mnist, projection, random_sphere, to_categorical, least_likely_class
99
from art.utils import random_targets, get_label_conf, get_labels_np_array, preprocess
1010

1111
logger = logging.getLogger('testLogger')
@@ -74,6 +74,25 @@ def test_random_targets(self):
7474
random_y = random_targets(y_, 10)
7575
self.assertTrue(np.all(y != random_y.argmax(axis=1)))
7676

77+
def test_least_likely_class(self):
78+
class DummyClassifier():
79+
@property
80+
def nb_classes(self):
81+
return 4
82+
83+
def predict(self, x):
84+
fake_preds = [0.1, 0.2, 0.05, 0.65]
85+
return np.array([fake_preds] * x.shape[0])
86+
87+
batch_size = 5
88+
x = np.random.rand(batch_size, 10, 10, 1)
89+
classifier = DummyClassifier()
90+
preds = least_likely_class(x, classifier)
91+
self.assertTrue(preds.shape == (batch_size, classifier.nb_classes))
92+
93+
expected_preds = np.array([[0, 0, 1, 0]] * batch_size)
94+
self.assertTrue((preds == expected_preds).all())
95+
7796
def test_get_label_conf(self):
7897
y = np.array([3, 1, 4, 1, 5, 9])
7998
y_ = to_categorical(y)

0 commit comments

Comments
 (0)