Skip to content

Commit ef53a70

Browse files
authored
Merge pull request #1890 from keykholt/verification_fix
Small clarification fix for verification alg
2 parents 0261594 + 4559032 commit ef53a70

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

art/metrics/verification_decisions_trees.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import numpy as np
2727
from tqdm.auto import trange
2828

29+
from art.utils import check_and_transform_label_format
30+
2931
if TYPE_CHECKING:
3032
from art.estimators.classification.classifier import ClassifierDecisionTree
3133

@@ -192,16 +194,25 @@ def verify(
192194
Verify the robustness of the classifier on the dataset `(x, y)`.
193195
194196
:param x: Feature data of shape `(nb_samples, nb_features)`.
195-
:param y: Labels, one-vs-rest encoding of shape `(nb_samples, nb_classes)`.
197+
:param y: Labels, one-hot-encoded of shape `(nb_samples, nb_classes)` or indices of shape
198+
(nb_samples,)`.
196199
:param eps_init: Attack budget for the first search step.
197200
:param norm: The norm to apply epsilon.
198201
:param nb_search_steps: The number of search steps.
199202
:param max_clique: The maximum number of nodes in a clique.
200203
:param max_level: The maximum number of clique search levels.
201204
:return: A tuple of the average robustness bound and the verification error at `eps`.
202205
"""
206+
if np.min(x) < 0 or np.max(x) > 1:
207+
raise ValueError(
208+
"There are features not in the range [0, 1]. The current implementation only supports normalized input"
209+
"values in range [0 1]."
210+
)
211+
203212
self.x: np.ndarray = x
204-
self.y: np.ndarray = np.argmax(y, axis=1)
213+
self.y: np.ndarray = check_and_transform_label_format(
214+
y, nb_classes=self._classifier.nb_classes, return_one_hot=False
215+
)
205216
self.max_clique: int = max_clique
206217
self.max_level: int = max_level
207218

0 commit comments

Comments
 (0)