Commit 01f05a8

Merge pull request #654 from Trusted-AI/development_issue_622
Introduce module art.evaluations and implement Security Curve evaluation
2 parents b5d8066 + af70c58 commit 01f05a8

File tree

6 files changed, +592 -0 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -113,3 +113,4 @@ demo/pics/*
 !notebooks/*.ipynb
 !notebooks/adaptive_defence_evaluations/*.ipynb
 !notebooks/adversarial_patch/*.ipynb
+!notebooks/art_evaluations/*.ipynb

art/evaluations/__init__.py

Whitespace-only changes.

art/evaluations/evaluation.py

Lines changed: 26 additions & 0 deletions

# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from abc import ABC, abstractmethod
from typing import Any


class Evaluation(ABC):
    @abstractmethod
    def evaluate(self, *args, **kwargs) -> Any:
        raise NotImplementedError
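For orientation, a minimal sketch of how a concrete evaluation might subclass this interface; the class `AccuracyEvaluation` and its metric are illustrative assumptions, not part of this commit:

import numpy as np

from art.evaluations.evaluation import Evaluation


class AccuracyEvaluation(Evaluation):
    # Illustrative subclass: scores one-hot predictions against one-hot labels.

    def evaluate(self, y: np.ndarray, y_pred: np.ndarray) -> float:
        # Fraction of samples whose predicted class matches the true class.
        return float(np.mean(np.argmax(y, axis=1) == np.argmax(y_pred, axis=1)))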
art/evaluations/security_curve/__init__.py

Lines changed: 4 additions & 0 deletions

"""
This module implements the evaluation of Security Curves.
"""
from art.evaluations.security_curve.security_curve import SecurityCurve
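Because this package `__init__` re-exports the class, downstream code can use the shorter import path (a usage note, not part of the diff):

from art.evaluations.security_curve import SecurityCurve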
art/evaluations/security_curve/security_curve.py

Lines changed: 182 additions & 0 deletions

# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the evaluation of Security Curves.

Examples of Security Curves can be found in Figure 6 of Madry et al., 2017 (https://arxiv.org/abs/1706.06083).
"""
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
from matplotlib import pyplot as plt

from art.evaluations.evaluation import Evaluation
from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent import ProjectedGradientDescent

if TYPE_CHECKING:
    from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


class SecurityCurve(Evaluation):
    """
    This class implements the evaluation of Security Curves.

    Examples of Security Curves can be found in Figure 6 of Madry et al., 2017 (https://arxiv.org/abs/1706.06083).
    """

    def __init__(self, eps: Union[int, List[float], List[int]]):
        """
        Create an instance of a Security Curve evaluation.

        :param eps: Defines the attack budgets `eps` for Projected Gradient Descent used for evaluation. If `eps` is
                    an integer, the range between the classifier's clip values is divided into `eps` equally spaced
                    budgets; if it is a list, the listed budgets are used directly.
        """
        self.eps = eps
        self.eps_list: List[float] = list()
        self.accuracy_adv_list: List[float] = list()
        self.accuracy: Optional[float] = None
        self._is_obfuscating_gradients = False

    def evaluate(
        self,
        classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
        x: np.ndarray,
        y: np.ndarray,
        **kwargs: Union[str, bool, int, float]
    ) -> Tuple[List[float], List[float], float]:
        """
        Evaluate the Security Curve of a classifier using Projected Gradient Descent.

        :param classifier: A trained classifier that provides loss gradients.
        :param x: Input data to the classifier for evaluation.
        :param y: True labels for input data `x`.
        :param kwargs: Keyword arguments for the Projected Gradient Descent attack used for evaluation, except the
                       keywords `classifier` and `eps`.
        :return: List of evaluated `eps` values, list of adversarial accuracies, and benign accuracy.
        """
        kwargs.pop("classifier", None)
        kwargs.pop("eps", None)
        self.eps_list.clear()
        self.accuracy_adv_list.clear()
        self.accuracy = None

        # Check type of eps: an integer defines the number of equally spaced budgets, a list is used as-is
        if isinstance(self.eps, int):
            eps_increment = (classifier.clip_values[1] - classifier.clip_values[0]) / self.eps

            for i in range(1, self.eps + 1):
                self.eps_list.append(i * eps_increment)
        else:
            self.eps_list = self.eps.copy()

        # Determine benign accuracy
        y_pred = classifier.predict(x=x)
        self.accuracy = self._get_accuracy(y=y, y_pred=y_pred)

        # Determine adversarial accuracy for each eps
        for eps in self.eps_list:
            attack_pgd = ProjectedGradientDescent(estimator=classifier, eps=eps, **kwargs)

            x_adv = attack_pgd.generate(x=x, y=y)

            y_pred_adv = classifier.predict(x=x_adv)
            accuracy_adv = self._get_accuracy(y=y, y_pred=y_pred_adv)
            self.accuracy_adv_list.append(accuracy_adv)

        # Check gradients for potential obfuscation
        self._check_gradient(classifier=classifier, x=x, y=y, **kwargs)

        return self.eps_list, self.accuracy_adv_list, self.accuracy

    @property
    def detected_obfuscating_gradients(self) -> bool:
        """
        This property describes if the previous call to method `evaluate` identified potential gradient obfuscation.
        """
        return self._is_obfuscating_gradients

    def _check_gradient(
        self,
        classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
        x: np.ndarray,
        y: np.ndarray,
        **kwargs: Union[str, bool, int, float]
    ) -> None:
        """
        Check if potential gradient obfuscation can be detected. Projected Gradient Descent with 100 iterations is run
        with the maximum attack budget `eps` equal to the upper clip value of the input data and an `eps_step` of
        `eps / (max_iter / 2)`.

        :param classifier: A trained classifier that provides loss gradients.
        :param x: Input data to the classifier for evaluation.
        :param y: True labels for input data `x`.
        :param kwargs: Keyword arguments for the Projected Gradient Descent attack used for evaluation, except the
                       keywords `classifier` and `eps`.
        """
        # Define parameters for Projected Gradient Descent
        max_iter = 100
        kwargs["max_iter"] = max_iter
        kwargs["eps"] = classifier.clip_values[1]
        kwargs["eps_step"] = classifier.clip_values[1] / (max_iter / 2)

        # Create attack
        attack_pgd = ProjectedGradientDescent(estimator=classifier, **kwargs)

        # Evaluate accuracy with maximal attack budget
        x_adv = attack_pgd.generate(x=x, y=y)
        y_pred_adv = classifier.predict(x=x_adv)
        accuracy_adv = self._get_accuracy(y=y, y_pred=y_pred_adv)

        # Decide if obfuscated gradients are likely: an attack with maximal budget should reduce accuracy to chance
        # level (1 / nb_classes) or below; anything higher suggests the gradients are unreliable
        if accuracy_adv > 1 / classifier.nb_classes:
            self._is_obfuscating_gradients = True
        else:
            self._is_obfuscating_gradients = False

    def plot(self) -> None:
        """
        Plot the Security Curve of adversarial accuracy as a function of attack budget `eps`, together with the
        accuracy on benign samples.
        """
        plt.plot(self.eps_list, self.accuracy_adv_list, label="adversarial", marker="o")
        plt.plot([self.eps_list[0], self.eps_list[-1]], [self.accuracy, self.accuracy], linestyle="--", label="benign")
        plt.legend()
        plt.xlabel("Attack budget eps")
        plt.ylabel("Accuracy")
        if self.detected_obfuscating_gradients:
            plt.title("Potential gradient obfuscation detected")
        else:
            plt.title("No gradient obfuscation detected")
        plt.ylim([0, 1.05])
        plt.show()

    @staticmethod
    def _get_accuracy(y: np.ndarray, y_pred: np.ndarray) -> float:
        """
        Calculate the accuracy of predicted labels.

        :param y: True labels.
        :param y_pred: Predicted labels.
        :return: Accuracy.
        """
        return np.mean(np.argmax(y, axis=1) == np.argmax(y_pred, axis=1)).item()

    def __repr__(self):
        return "{}(eps={})".format(self.__module__ + "." + self.__class__.__name__, self.eps)
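To show how the pieces fit together, here is a minimal end-to-end usage sketch, not part of the commit: it assumes ART's `SklearnClassifier` wrapper exposes loss gradients for a fitted scikit-learn `LogisticRegression` so that Projected Gradient Descent can run against it, and the toy data, `eps_step`, and `max_iter` values are illustrative.

import numpy as np
from sklearn.linear_model import LogisticRegression

from art.estimators.classification import SklearnClassifier
from art.evaluations.security_curve import SecurityCurve

# Toy two-class data in [0, 1] with one-hot labels, as expected by `evaluate`
x = np.random.rand(100, 4).astype(np.float32)
labels = np.random.randint(0, 2, size=100)
y = np.eye(2)[labels]

# Train a simple scikit-learn model and wrap it as an ART classifier
model = LogisticRegression()
model.fit(x, labels)
classifier = SklearnClassifier(model=model, clip_values=(0.0, 1.0))

# Five equally spaced budgets between the clip values; extra kwargs are passed to PGD
sec_curve = SecurityCurve(eps=5)
eps_list, accuracy_adv_list, accuracy = sec_curve.evaluate(
    classifier=classifier, x=x, y=y, eps_step=0.05, max_iter=10
)

print("Benign accuracy:", accuracy)
print("Adversarial accuracies:", list(zip(eps_list, accuracy_adv_list)))
print("Obfuscated gradients suspected:", sec_curve.detected_obfuscating_gradients)
sec_curve.plot()

Since `evaluate` strips the `classifier` and `eps` keywords from `kwargs`, any remaining keyword arguments are forwarded unchanged to `ProjectedGradientDescent` for every budget on the curve.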
