Skip to content

Commit 28397f5

Browse files
Irina Nicolae
authored and committed
Add CLEVER metric
2 parents f687a9d + 7dbc91a commit 28397f5

File tree

4 files changed

+312
-36
lines changed

4 files changed

+312
-36
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ install:
5353

5454
script:
5555
- mkdir ./data
56-
- python -m unittest discover src/ -p '*_unittest.py'
56+
- python -m unittest discover art/ -p '*_unittest.py'

art/metrics.py

Lines changed: 188 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,3 @@
1-
# MIT License
2-
#
3-
# Copyright (C) IBM Corporation 2018
4-
#
5-
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6-
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7-
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8-
# persons to whom the Software is furnished to do so, subject to the following conditions:
9-
#
10-
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11-
# Software.
12-
#
13-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14-
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16-
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17-
# SOFTWARE.
181
"""
192
Module implementing various metrics for assessing model robustness. These fall mainly under two categories:
203
attack-dependent and attack-independent.
@@ -25,6 +8,10 @@
258
import numpy as np
269
import numpy.linalg as la
2710
import tensorflow as tf
11+
from scipy.stats import weibull_min
12+
from scipy.optimize import fmin as scipy_optimizer
13+
from scipy.special import gammainc
14+
from functools import reduce
2815

2916
from art.attacks.fast_gradient import FastGradientMethod
3017

@@ -196,3 +183,187 @@ def loss_sensitivity(x, classifier, sess):
196183
res = la.norm(res.reshape(res.shape[0], -1), ord=2, axis=1)
197184

198185
return np.mean(res)
186+
187+
188+
def clever_u(x, classifier, n_b, n_s, r, sess, c_init=1):
    """
    Compute CLEVER score for an untargeted attack. Paper link: https://arxiv.org/abs/1801.10578

    The untargeted score is obtained by running the targeted computation (`clever_t`) against every class other
    than the one predicted for `x`, and keeping the smallest score per norm.

    :param x: One input sample
    :type x: `np.ndarray`
    :param classifier: A trained model.
    :type classifier: :class:`Classifier`
    :param n_b: Number of batches of random samples (forwarded to `clever_t`)
    :type n_b: `int`
    :param n_s: Number of random samples per batch (forwarded to `clever_t`)
    :type n_s: `int`
    :param r: Maximum perturbation
    :type r: `float`
    :param sess: The session to run graphs in
    :type sess: `tf.Session`
    :param c_init: initialization of Weibull distribution
    :type c_init: `float`
    :return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
    :rtype: `tuple`
    """
    # Class currently predicted for x, and the total number of classes
    predictions = classifier.predict(np.array([x]))
    predicted = np.argmax(predictions, axis=1)[0]
    nb_classes = np.shape(predictions)[1]

    # One targeted (norm 1, norm 2, norm inf) score triple per class other than the predicted one
    per_class = [clever_t(x, classifier, cls, n_b, n_s, r, sess, c_init)
                 for cls in range(nb_classes) if cls != predicted]

    # The untargeted score for each norm is the smallest targeted score
    return (np.min([scores[0] for scores in per_class]),
            np.min([scores[1] for scores in per_class]),
            np.min([scores[2] for scores in per_class]))
224+
225+
226+
def clever_t(x, classifier, target_class, n_b, n_s, r, sess, c_init=1):
    """
    Compute CLEVER score for a targeted attack. Paper link: https://arxiv.org/abs/1801.10578

    Random points are drawn around `x` in `n_b` batches of `n_s` samples; for each batch the maximum gradient norm
    of `g = output[pred_class] - output[target_class]` is recorded. A Weibull distribution is fitted to the negated
    maxima and its location parameter is used, together with `g(x)`, to compute the scores.

    :param x: One input sample
    :type x: `np.ndarray`
    :param classifier: A trained model
    :type classifier: :class:`Classifier`
    :param target_class: Targeted class
    :type target_class: `int`
    :param n_b: Number of batches of random samples
    :type n_b: `int`
    :param n_s: Number of random samples per batch
    :type n_s: `int`
    :param r: Maximum perturbation; used as the sampling radius and as an upper bound on every score
    :type r: `float`
    :param sess: The session to run graphs in
    :type sess: `tf.Session`
    :param c_init: Initialization of Weibull distribution (starting estimate passed to `weibull_min.fit`)
    :type c_init: `float`
    :return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
    :rtype: `tuple`
    :raises ValueError: If `target_class` is the class predicted for `x`
    """
    # Check if the targeted class is different from the predicted class.
    # The prediction is computed once and reused for the rest of the function.
    y_pred = classifier.predict(np.array([x]))
    pred_class = np.argmax(y_pred, axis=1)[0]
    if target_class == pred_class:
        raise ValueError("The targeted class is the predicted class!")

    def _prepare(data):
        # Apply feature squeezing when the classifier offers it, then the classifier's own preprocessing
        if hasattr(classifier, 'feature_squeeze'):
            data = classifier.feature_squeeze(data)
        return classifier._preprocess(data)

    # Define placeholders for computing g gradients
    imgs = tf.placeholder(shape=[None] + list(x.shape), dtype=tf.float32)
    pred_class_ph = tf.placeholder(dtype=tf.int32, shape=[])
    target_class_ph = tf.placeholder(dtype=tf.int32, shape=[])

    # Define tensors for g gradients
    grad_norm_1, grad_norm_2, grad_norm_8, g_x = _build_g_gradient(imgs, classifier, pred_class_ph, target_class_ph)

    # Some auxiliary vars
    set1, set2, set8 = [], [], []
    dim = int(np.prod(x.shape))
    batch_shape = [n_s] + list(x.shape)
    class_feed = {pred_class_ph: pred_class, target_class_ph: target_class}

    # Loop over n_b batches
    for _ in range(n_b):
        # Random generation of data points around x, clipped to [0, 1]
        sample_xs0 = np.reshape(_random_sphere(m=n_s, n=dim, r=r), batch_shape)
        sample_xs = sample_xs0 + np.repeat(np.array([x]), n_s, 0)
        np.clip(sample_xs, 0, 1, out=sample_xs)
        sample_xs = _prepare(sample_xs)

        # Compute the maximum gradient norms over the batch
        feed = dict(class_feed)
        feed[imgs] = sample_xs
        max_gn1, max_gn2, max_gn8 = sess.run([grad_norm_1, grad_norm_2, grad_norm_8], feed_dict=feed)
        set1.append(max_gn1)
        set2.append(max_gn2)
        set8.append(max_gn8)

    # Maximum likelihood estimation for max gradient norms.
    # The maxima are negated before fitting `weibull_min`; only the fitted location parameter is kept.
    [_, loc1, _] = weibull_min.fit(-np.array(set1), c_init, optimizer=scipy_optimizer)
    [_, loc2, _] = weibull_min.fit(-np.array(set2), c_init, optimizer=scipy_optimizer)
    [_, loc8, _] = weibull_min.fit(-np.array(set8), c_init, optimizer=scipy_optimizer)

    # Compute g_x0
    feed = dict(class_feed)
    feed[imgs] = _prepare(np.array([x]))
    g_x0 = sess.run(g_x, feed_dict=feed)

    # Compute scores
    # Note q = p / (p-1): the score for norm p is computed from the gradient norm q,
    # so the inf-norm score uses the norm-1 gradients and the norm-1 score uses the inf-norm gradients.
    s8 = np.min([-g_x0[0] / loc1, r])
    s2 = np.min([-g_x0[0] / loc2, r])
    s1 = np.min([-g_x0[0] / loc8, r])

    return s1, s2, s8
316+
317+
318+
def _build_g_gradient(x, classifier, pred_class, target_class):
    """
    Build tensors of gradient `g`, where `g` is the model output for `pred_class` minus the output for
    `target_class`.

    :param x: Input on which the model is applied and with respect to which gradients are taken
              (passed to `tf.gradients`, so presumably a TensorFlow tensor/placeholder -- confirm with callers)
    :param classifier: A trained model
    :type classifier: :class:`Classifier`
    :param pred_class: Predicted class (used to index the second axis of the model output)
    :param target_class: Target class (used to index the second axis of the model output)
    :return: Maximum over the first axis of the norm 1, norm 2 and norm inf of the flattened gradients,
             followed by the tensor `g` itself
    :rtype: `tuple`
    """
    # Difference between the predicted-class and target-class outputs
    outputs = classifier.model(x)
    margin = outputs[:, pred_class] - outputs[:, target_class]

    # Gradient of g w.r.t. the input, flattened to one row per sample
    grads = tf.gradients(margin, x)[0]
    flat_grads = tf.reshape(grads, (tf.shape(grads)[0], -1))

    # Largest per-sample gradient norm for each order
    max_norms = [tf.reduce_max(tf.norm(flat_grads, ord=order, axis=1)) for order in (1, 2, np.inf)]

    return max_norms[0], max_norms[1], max_norms[2], margin
349+
350+
351+
def _random_sphere(m, n, r):
    """
    Generate `m` random points of dimension `n` lying within the ball of radius `r` centered around 0.

    Each point is a Gaussian draw rescaled to have norm `r * gammainc(n / 2, s / 2) ** (1 / n)`, where `s` is the
    squared norm of the draw. Since the regularized incomplete gamma function takes values in [0, 1], the norm of
    every returned point is at most `r`.

    :param m: Number of random data points
    :type m: `int`
    :param n: Dimension
    :type n: `int`
    :param r: Radius
    :type r: `float`
    :return: Array of shape `(m, n)` holding the generated points
    :rtype: `np.ndarray`
    """
    directions = np.random.randn(m, n)
    sq_norms = np.sum(directions ** 2, axis=1)

    # Float constants force true division even when `/` on two integers would truncate
    scale = gammainc(n / 2.0, sq_norms / 2.0) ** (1.0 / n) * r / np.sqrt(sq_norms)

    # Broadcast the per-point scale over the coordinates instead of materialising a tiled copy
    return directions * scale[:, np.newaxis]

art/metrics_unittest.py

Lines changed: 118 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,3 @@
1-
# MIT License
2-
#
3-
# Copyright (C) IBM Corporation 2018
4-
#
5-
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6-
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7-
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8-
# persons to whom the Software is furnished to do so, subject to the following conditions:
9-
#
10-
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11-
# Software.
12-
#
13-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14-
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16-
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17-
# SOFTWARE.
18-
191
from __future__ import absolute_import, division, print_function, unicode_literals
202

213
import unittest
@@ -27,6 +9,7 @@
279
from art.classifiers.cnn import CNN
2810
from art.metrics import empirical_robustness
2911
from art.utils import load_mnist, load_cifar10
12+
from art.metrics import clever_t, clever_u
3013
from art.classifiers.classifier import Classifier
3114

3215
BATCH_SIZE = 10
@@ -100,5 +83,122 @@ def test_emp_robustness_mnist(self):
10083
# self.assertLessEqual(emp_robust_jsma, 1.)
10184

10285

86+
#########################################
87+
# This part is the unit test for Clever.#
88+
#########################################
89+
90+
class TestClassifier(Classifier):
    """
    Minimal classifier used by the CLEVER unit tests: a Keras `Sequential` model made of a single `Lambda` layer
    that adds 0 to its 2-dimensional input.
    """
    def __init__(self, defences=None, preproc=None):
        # Keras is imported lazily, only when the test classifier is instantiated
        from keras.layers import Lambda
        from keras.models import Sequential

        passthrough = Sequential(name="TestClassifier")
        passthrough.add(Lambda(lambda inputs: inputs + 0, input_shape=(2,)))

        super(TestClassifier, self).__init__(passthrough, defences, preproc)
98+
99+
100+
class TestClever(unittest.TestCase):
    """
    Unittest for Clever metrics.
    """
    @staticmethod
    def _new_session():
        """
        Create a TensorFlow session and register it with the Keras backend.
        """
        sess = tf.Session()
        k.set_session(sess)
        return sess

    @staticmethod
    def _train_mnist_cnn():
        """
        Train a CNN for one epoch on the first `NB_TRAIN` MNIST samples.

        :return: The trained classifier and the last training sample used
        """
        compile_args = {"loss": 'categorical_crossentropy', "optimizer": 'adam',
                        "metrics": ['accuracy']}

        # Get MNIST
        (x_train, y_train), (_, _), _, _ = load_mnist()
        x_train, y_train = x_train[:NB_TRAIN], y_train[:NB_TRAIN]

        # Get classifier
        cnn = CNN(x_train[0].shape, act="relu")
        cnn.compile(compile_args)
        cnn.fit(x_train, y_train, epochs=1,
                batch_size=BATCH_SIZE, verbose=0)

        return cnn, x_train[-1]

    def _check_unit_scores(self, scores):
        """
        Check the (norm 1, norm 2, norm inf) scores expected for `TestClassifier` on the sample `[1, 0]`.
        """
        self.assertAlmostEqual(scores[0], 0.9999999999999998, delta=0.00001)
        self.assertAlmostEqual(scores[1], 0.7071067811865474, delta=0.00001)
        self.assertAlmostEqual(scores[2], 0.4999999999999999, delta=0.00001)

    def _check_decreasing(self, scores):
        """
        Check that the norm 1 score exceeds the norm 2 score, which exceeds the norm inf score.
        """
        self.assertGreater(scores[0], scores[1])
        self.assertGreater(scores[1], scores[2])

    def test_clever_t_unit(self):
        """
        Test the targeted version with simplified data.
        :return:
        """
        print("Unit test for the targeted version with simplified data.")
        session = self._new_session()
        classifier = TestClassifier()

        self._check_unit_scores(clever_t(np.array([1, 0]), classifier, 1, 20, 10, 1, session))

    def test_clever_u_unit(self):
        """
        Test the untargeted version with simplified data.
        :return:
        """
        print("Unit test for the untargeted version with simplified data.")
        session = self._new_session()
        classifier = TestClassifier()

        self._check_unit_scores(clever_u(np.array([1, 0]), classifier, 20, 10, 1, session))

    def test_clever_t(self):
        """
        Test the targeted version.
        :return:
        """
        print("Test if the targeted version works on a true classifier/data")
        session = self._new_session()
        classifier, sample = self._train_mnist_cnn()

        self._check_decreasing(clever_t(sample, classifier, 7, 20, 10, 5, session))

    def test_clever_u(self):
        """
        Test the untargeted version.
        :return:
        """
        print("Test if the untargeted version works on a true classifier/data")
        session = self._new_session()
        classifier, sample = self._train_mnist_cnn()

        self._check_decreasing(clever_u(sample, classifier, 2, 10, 5, session))
201+
202+
103203
if __name__ == '__main__':
104204
unittest.main()

docs/modules/metrics.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,8 @@ Empirical Robustness
1313
Distance to nearest neighbors
1414
-----------------------------
1515
.. autofunction:: nearest_neighbour_dist
16+
17+
CLEVER
18+
------
19+
.. autofunction:: clever_u
20+
.. autofunction:: clever_t

0 commit comments

Comments
 (0)