|
@@ -1,20 +1,3 @@
-# MIT License
-#
-# Copyright (C) IBM Corporation 2018
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
-# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
-# persons to whom the Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
-# Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
-# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
-# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
 """
 Module implementing varying metrics for assessing model robustness. These fall mainly under two categories:
 attack-dependent and attack-independent.
@@ -25,6 +8,10 @@
 import numpy as np
 import numpy.linalg as la
 import tensorflow as tf
+from scipy.stats import weibull_min
+from scipy.optimize import fmin as scipy_optimizer
+from scipy.special import gammainc
+from functools import reduce

 from art.attacks.fast_gradient import FastGradientMethod

@@ -196,3 +183,187 @@ def loss_sensitivity(x, classifier, sess): |
     res = la.norm(res.reshape(res.shape[0], -1), ord=2, axis=1)

     return np.mean(res)
+
+
+def clever_u(x, classifier, n_b, n_s, r, sess, c_init=1):
+    """
+    Compute CLEVER score for an untargeted attack. Paper link: https://arxiv.org/abs/1801.10578
+
+    :param x: One input sample
+    :type x: `np.ndarray`
+    :param classifier: A trained model
+    :type classifier: :class:`Classifier`
+    :param n_b: Number of batches over which the maximum gradient norm is estimated
+    :type n_b: `int`
+    :param n_s: Number of examples per batch
+    :type n_s: `int`
+    :param r: Maximum perturbation
+    :type r: `float`
+    :param sess: The session to run graphs in
+    :type sess: `tf.Session`
+    :param c_init: Initialization of the Weibull distribution shape parameter
+    :type c_init: `float`
+    :return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
+    :rtype: `tuple`
+    """
+    # Get a list of untargeted classes
+    y_pred = classifier.predict(np.array([x]))
+    pred_class = np.argmax(y_pred, axis=1)[0]
+    num_class = np.shape(y_pred)[1]
+    untarget_classes = [i for i in range(num_class) if i != pred_class]
+
+    # Compute CLEVER score for each untargeted class
+    score1_list, score2_list, score8_list = [], [], []
+    for j in untarget_classes:
+        s1, s2, s8 = clever_t(x, classifier, j, n_b, n_s, r, sess, c_init)
+        score1_list.append(s1)
+        score2_list.append(s2)
+        score8_list.append(s8)
+
+    return np.min(score1_list), np.min(score2_list), np.min(score8_list)
+
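+# A minimal usage sketch for `clever_u` (illustrative only: `classifier`, `sess`,
+# `x_test` and the parameter values below are assumptions, not part of this module):
+#
+#     x0 = x_test[0]  # one sample, scaled to [0, 1]
+#     s1, s2, s8 = clever_u(x0, classifier, n_b=10, n_s=20, r=5, sess=sess)
+#     # s1, s2, s8 estimate lower bounds on the minimal L1, L2 and Linf
+#     # perturbations needed to change the predicted class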
+
+
+def clever_t(x, classifier, target_class, n_b, n_s, r, sess, c_init=1):
+    """
+    Compute CLEVER score for a targeted attack. Paper link: https://arxiv.org/abs/1801.10578
+
+    :param x: One input sample
+    :type x: `np.ndarray`
+    :param classifier: A trained model
+    :type classifier: :class:`Classifier`
+    :param target_class: Targeted class
+    :type target_class: `int`
+    :param n_b: Number of batches over which the maximum gradient norm is estimated
+    :type n_b: `int`
+    :param n_s: Number of examples per batch
+    :type n_s: `int`
+    :param r: Maximum perturbation
+    :type r: `float`
+    :param sess: The session to run graphs in
+    :type sess: `tf.Session`
+    :param c_init: Initialization of the Weibull distribution shape parameter
+    :type c_init: `float`
+    :return: A tuple of 3 CLEVER scores, corresponding to norms 1, 2 and np.inf
+    :rtype: `tuple`
+    """
+    # Check that the targeted class differs from the predicted class
+    y_pred = classifier.predict(np.array([x]))
+    pred_class = np.argmax(y_pred, axis=1)[0]
+    if target_class == pred_class:
+        raise ValueError("The targeted class is the predicted class.")
+
+    # Define placeholders for computing g gradients
+    shape = [None]
+    shape.extend(x.shape)
+    imgs = tf.placeholder(shape=shape, dtype=tf.float32)
+    pred_class_ph = tf.placeholder(dtype=tf.int32, shape=[])
+    target_class_ph = tf.placeholder(dtype=tf.int32, shape=[])
+
+    # Define tensors for g gradients
+    grad_norm_1, grad_norm_2, grad_norm_8, g_x = _build_g_gradient(imgs, classifier, pred_class_ph, target_class_ph)
+
+    # Some auxiliary vars
+    set1, set2, set8 = [], [], []
+    dim = reduce(lambda x_, y: x_ * y, x.shape, 1)
+    shape = [n_s]
+    shape.extend(x.shape)
+
+    # Loop over n_b batches
+    for i in range(n_b):
+        # Random generation of data points
+        sample_xs0 = np.reshape(_random_sphere(m=n_s, n=dim, r=r), shape)
+        sample_xs = sample_xs0 + np.repeat(np.array([x]), n_s, 0)
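+        # Keep the samples in the valid input range (this assumes inputs scaled to [0, 1])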
+        np.clip(sample_xs, 0, 1, out=sample_xs)
+
+        # Preprocess data if it is supported in the classifier
+        if hasattr(classifier, 'feature_squeeze'):
+            sample_xs = classifier.feature_squeeze(sample_xs)
+        sample_xs = classifier._preprocess(sample_xs)
+
+        # Compute gradients
+        max_gn1, max_gn2, max_gn8 = sess.run(
+            [grad_norm_1, grad_norm_2, grad_norm_8],
+            feed_dict={imgs: sample_xs, pred_class_ph: pred_class,
+                       target_class_ph: target_class})
+        set1.append(max_gn1)
+        set2.append(max_gn2)
+        set8.append(max_gn8)
+
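+    # Per the extreme value theory argument of the CLEVER paper, the per-batch
+    # maxima of the gradient norms follow a (reverse) Weibull distribution;
+    # negating the maxima lets `weibull_min` fit it, and the fitted location
+    # parameter estimates the true maximum gradient norm.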
+    # Maximum likelihood estimation for max gradient norms
+    [_, loc1, _] = weibull_min.fit(-np.array(set1), c_init, optimizer=scipy_optimizer)
+    [_, loc2, _] = weibull_min.fit(-np.array(set2), c_init, optimizer=scipy_optimizer)
+    [_, loc8, _] = weibull_min.fit(-np.array(set8), c_init, optimizer=scipy_optimizer)
+
+    # Compute g_x0
+    x0 = np.array([x])
+    if hasattr(classifier, 'feature_squeeze'):
+        x0 = classifier.feature_squeeze(x0)
+    x0 = classifier._preprocess(x0)
+    g_x0 = sess.run(g_x, feed_dict={imgs: x0, pred_class_ph: pred_class,
+                                    target_class_ph: target_class})
+
+    # Compute scores. The Lp score uses the Lq gradient norm, where
+    # q = p / (p - 1) is the dual norm: p=inf -> q=1, p=2 -> q=2, p=1 -> q=inf
+    s8 = np.min([-g_x0[0] / loc1, r])
+    s2 = np.min([-g_x0[0] / loc2, r])
+    s1 = np.min([-g_x0[0] / loc8, r])
+
+    return s1, s2, s8
+
+
+def _build_g_gradient(x, classifier, pred_class, target_class):
+    """
+    Build tensors of gradient `g`.
+
+    :param x: Placeholder for input samples
+    :type x: `tf.Tensor`
+    :param classifier: A trained model
+    :type classifier: :class:`Classifier`
+    :param pred_class: Predicted class
+    :type pred_class: `int`
+    :param target_class: Target class
+    :type target_class: `int`
+    :return: Tensors of the three max gradient norms (for norms 1, 2 and np.inf) and of `g` itself
+    :rtype: `tuple`
+    """
+    # Get predicted values
+    y_pred = classifier.model(x)
+    pred_val = y_pred[:, pred_class]
+    target_val = y_pred[:, target_class]
+    g_x = pred_val - target_val
+
+    # Get the gradient op
+    grad_op = tf.gradients(g_x, x)[0]
+
+    # Compute the gradient norms
+    grad_op_rs = tf.reshape(grad_op, (tf.shape(grad_op)[0], -1))
+    grad_norm_1 = tf.reduce_max(tf.norm(grad_op_rs, ord=1, axis=1))
+    grad_norm_2 = tf.reduce_max(tf.norm(grad_op_rs, ord=2, axis=1))
+    grad_norm_8 = tf.reduce_max(tf.norm(grad_op_rs, ord=np.inf, axis=1))
+
+    return grad_norm_1, grad_norm_2, grad_norm_8, g_x
+
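+# Illustrative use of `_build_g_gradient` (assumes the placeholders defined in
+# `clever_t`; the class indices 3 and 7 are made up for the example):
+#
+#     gn1, gn2, gn8, g = _build_g_gradient(imgs, classifier, pred_class_ph, target_class_ph)
+#     v1, v2, v8 = sess.run([gn1, gn2, gn8],
+#                           feed_dict={imgs: batch, pred_class_ph: 3, target_class_ph: 7})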
+
+def _random_sphere(m, n, r):
+    """
+    Randomly generate `m` points of dimension `n` inside the sphere of radius `r` centered at 0.
+
+    :param m: Number of random data points
+    :type m: `int`
+    :param n: Dimension
+    :type n: `int`
+    :param r: Radius
+    :type r: `float`
+    :return: The generated random points
+    :rtype: `np.ndarray`
+    """
+    a = np.random.randn(m, n)
+    s2 = np.sum(a**2, axis=1)
+    base = gammainc(n/2, s2/2)**(1/n) * r / np.sqrt(s2)
+    a = a * (np.tile(base, (n, 1))).T
+
+    return a
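+# Quick sanity check for `_random_sphere` (illustrative only, not part of the
+# module): every generated point should lie inside the L2 ball of radius `r`.
+#
+#     pts = _random_sphere(m=1000, n=3, r=2.0)
+#     assert pts.shape == (1000, 3)
+#     assert np.all(np.linalg.norm(pts, axis=1) <= 2.0 + 1e-9)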