|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 |
|
8 | | -from art.utils import load_mnist, projection, random_sphere, to_categorical |
| 8 | +from art.utils import load_mnist, projection, random_sphere, to_categorical, least_likely_class |
9 | 9 | from art.utils import random_targets, get_label_conf, get_labels_np_array, preprocess |
10 | 10 |
|
11 | 11 | logger = logging.getLogger('testLogger') |
@@ -74,6 +74,25 @@ def test_random_targets(self): |
74 | 74 | random_y = random_targets(y_, 10) |
75 | 75 | self.assertTrue(np.all(y != random_y.argmax(axis=1))) |
76 | 76 |
|
| 77 | + def test_least_likely_class(self): |
| 78 | + class DummyClassifier(): |
| 79 | + @property |
| 80 | + def nb_classes(self): |
| 81 | + return 4 |
| 82 | + |
| 83 | + def predict(self, x): |
| 84 | + fake_preds = [0.1, 0.2, 0.05, 0.65] |
| 85 | + return np.array([fake_preds] * x.shape[0]) |
| 86 | + |
| 87 | + batch_size = 5 |
| 88 | + x = np.random.rand(batch_size, 10, 10, 1) |
| 89 | + classifier = DummyClassifier() |
| 90 | + preds = least_likely_class(x, classifier) |
| 91 | + self.assertTrue(preds.shape == (batch_size, classifier.nb_classes)) |
| 92 | + |
| 93 | + expected_preds = np.array([[0, 0, 1, 0]] * batch_size) |
| 94 | + self.assertTrue((preds == expected_preds).all()) |
| 95 | + |
77 | 96 | def test_get_label_conf(self): |
78 | 97 | y = np.array([3, 1, 4, 1, 5, 9]) |
79 | 98 | y_ = to_categorical(y) |
|
0 commit comments