|
26 | 26 | import numpy as np |
27 | 27 | from tqdm.auto import trange |
28 | 28 |
|
| 29 | +from art.utils import check_and_transform_label_format |
| 30 | + |
29 | 31 | if TYPE_CHECKING: |
30 | 32 | from art.estimators.classification.classifier import ClassifierDecisionTree |
31 | 33 |
|
@@ -192,16 +194,25 @@ def verify( |
192 | 194 | Verify the robustness of the classifier on the dataset `(x, y)`. |
193 | 195 |
|
194 | 196 | :param x: Feature data of shape `(nb_samples, nb_features)`. |
195 | | - :param y: Labels, one-vs-rest encoding of shape `(nb_samples, nb_classes)`. |
| 197 | + :param y: Labels, one-hot-encoded of shape `(nb_samples, nb_classes)` or indices of shape |
| 198 | + (nb_samples,)`. |
196 | 199 | :param eps_init: Attack budget for the first search step. |
197 | 200 | :param norm: The norm to apply epsilon. |
198 | 201 | :param nb_search_steps: The number of search steps. |
199 | 202 | :param max_clique: The maximum number of nodes in a clique. |
200 | 203 | :param max_level: The maximum number of clique search levels. |
201 | 204 | :return: A tuple of the average robustness bound and the verification error at `eps`. |
202 | 205 | """ |
| 206 | + if np.min(x) < 0 or np.max(x) > 1: |
| 207 | + raise ValueError( |
| 208 | + "There are features not in the range [0, 1]. The current implementation only supports normalized input" |
| 209 | + "values in range [0 1]." |
| 210 | + ) |
| 211 | + |
203 | 212 | self.x: np.ndarray = x |
204 | | - self.y: np.ndarray = np.argmax(y, axis=1) |
| 213 | + self.y: np.ndarray = check_and_transform_label_format( |
| 214 | + y, nb_classes=self._classifier.nb_classes, return_one_hot=False |
| 215 | + ) |
205 | 216 | self.max_clique: int = max_clique |
206 | 217 | self.max_level: int = max_level |
207 | 218 |
|
|
0 commit comments