Skip to content

Commit 4dfd585

Browse files
author
Kevin Eykholt
committed
Small clarification fix for verification alg
Signed-off-by: Kevin Eykholt <[email protected]>
1 parent 3e3a438 commit 4dfd585

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

art/metrics/verification_decisions_trees.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import numpy as np
2727
from tqdm.auto import trange
2828

29+
from art.utils import check_and_transform_label_format
30+
2931
if TYPE_CHECKING:
3032
from art.estimators.classification.classifier import ClassifierDecisionTree
3133

@@ -192,16 +194,20 @@ def verify(
192194
Verify the robustness of the classifier on the dataset `(x, y)`.
193195
194196
:param x: Feature data of shape `(nb_samples, nb_features)`.
195-
:param y: Labels, one-vs-rest encoding of shape `(nb_samples, nb_classes)`.
197+
:param y: Labels, one-hot-encoded of shape `(nb_samples, nb_classes)` or indices of shape
198+
(nb_samples,)`.
196199
:param eps_init: Attack budget for the first search step.
197200
:param norm: The norm to apply epsilon.
198201
:param nb_search_steps: The number of search steps.
199202
:param max_clique: The maximum number of nodes in a clique.
200203
:param max_level: The maximum number of clique search levels.
201204
:return: A tuple of the average robustness bound and the verification error at `eps`.
202205
"""
206+
if np.min(x) < 0 or np.max(x) > 1:
207+
raise ValueError("There are features not in the range [0,1].")
208+
203209
self.x: np.ndarray = x
204-
self.y: np.ndarray = np.argmax(y, axis=1)
210+
self.y: np.ndarray = check_and_transform_label_format(y, nb_classes=self._classifier.nb_classes, return_one_hot=False)
205211
self.max_clique: int = max_clique
206212
self.max_level: int = max_level
207213

0 commit comments

Comments
 (0)