|
1 | | -from __future__ import absolute_import, division, print_function |
| 1 | +from __future__ import absolute_import, division, print_function, unicode_literals |
2 | 2 |
|
3 | | -from config import config_dict |
4 | | - |
5 | | -from cleverhans.attacks_tf import vatm |
6 | | -from keras import backend as k |
| 3 | +import numpy as np |
7 | 4 | import tensorflow as tf |
8 | 5 |
|
9 | | -from src.attacks.attack import Attack |
| 6 | +from src.attacks.attack import Attack, class_derivative |
10 | 7 |
|
11 | 8 |
|
12 | 9 | class VirtualAdversarialMethod(Attack): |
13 | 10 | """ |
14 | 11 | This attack was originally proposed by Miyato et al. (2016) and was used for virtual adversarial training. |
15 | 12 | Paper link: https://arxiv.org/abs/1507.00677 |
16 | 13 | """ |
17 | | - attack_params = ['max_iter', 'xi', 'clip_min', 'clip_max'] |
| 14 | + attack_params = ['eps', 'finite_diff', 'max_iter', 'clip_min', 'clip_max'] |
18 | 15 |
|
19 | | - def __init__(self, classifier, sess=None, max_iter=5, xi=1e-6, clip_min=None, clip_max=None): |
| 16 | + def __init__(self, classifier, sess=None, max_iter=1, finite_diff=1e-6, eps=.1, clip_min=0., clip_max=1.): |
20 | 17 | """ |
21 | 18 | Create a VirtualAdversarialMethod instance. |
| 19 | +
|
22 | 20 | :param classifier: A function that takes a symbolic input and returns the symbolic output for the classifier's |
23 | 21 | predictions. |
24 | | - :param sess: The tf session to run graphs in. |
25 | | - :param max_iter: (optional integer) The maximum number of iterations. |
26 | | - :param xi: (optional float) The finite difference parameter. |
| 22 | + :param sess: The tf session to run graphs in |
| 23 | + :param eps: (optional float) the epsilon (max input variation parameter) |
| 24 | + :param finite_diff: (optional float) The finite difference parameter |
| 25 | + :param max_iter: (optional integer) The maximum number of iterations |
27 | 26 | :param clip_min: (optional float) Minimum input component value |
28 | 27 | :param clip_max: (optional float) Maximum input component value |
29 | 28 | """ |
30 | 29 | super(VirtualAdversarialMethod, self).__init__(classifier, sess) |
31 | 30 |
|
32 | | - kwargs = {'max_iter': max_iter, 'xi': xi, 'clip_min': clip_min, 'clip_max': clip_max} |
| 31 | + kwargs = {'finite_diff': finite_diff, 'eps': eps, 'max_iter': max_iter, 'clip_min': clip_min, 'clip_max': clip_max} |
33 | 32 | self.set_params(**kwargs) |
34 | 33 |
|
35 | | - def generate_graph(self, x, eps=0.1, **kwargs): |
| 34 | + def generate(self, x_val, **kwargs): |
36 | 35 | """ |
37 | | - Generate symbolic graph for adversarial examples and return. |
38 | | - :param x: The model's symbolic inputs. |
| 36 | + Generate adversarial samples and return them in a Numpy array. |
| 37 | +
|
| 38 | + :param x_val: (required) A Numpy array with the original inputs |
39 | 39 | :param eps: (optional float) the epsilon (max input variation parameter) |
40 | | - :param max_iter: (optional integer) The maximum number of iterations. |
41 | | - :param xi: (optional float) The finite difference parameter. |
| 40 | + :param finite_diff: (optional float) The finite difference parameter |
| 41 | + :param max_iter: (optional integer) The maximum number of iterations |
42 | 42 | :param clip_min: (optional float) Minimum input component value |
43 | 43 | :param clip_max: (optional float) Maximum input component value |
| 44 | + :return: A Numpy array holding the adversarial examples |
| 45 | + :rtype: np.ndarray |
44 | 46 | """ |
| 47 | + # TODO Consider computing attack for a batch of samples at a time (no for loop) |
45 | 48 | # Parse and save attack-specific parameters |
46 | 49 | assert self.set_params(**kwargs) |
47 | 50 |
|
48 | | - return vatm(self.classifier, x, self.classifier._get_predictions(x, log=False), eps=eps, |
49 | | - num_iterations=self.max_iter, xi=self.xi, clip_min=self.clip_min, clip_max=self.clip_max) |
| 51 | + x_adv = np.copy(x_val) |
| 52 | + dims = [None] + list(x_val.shape[1:]) |
| 53 | + self._x = tf.placeholder(tf.float32, shape=dims) |
| 54 | + dims[0] = 1 |
| 55 | + self._preds = self.classifier._get_predictions(self._x, log=False) |
| 56 | + preds_val = self.sess.run(self._preds, {self._x: x_adv}) |
| 57 | + |
| 58 | + for ind, val in enumerate(x_adv): |
| 59 | + d = np.random.randn(*dims[1:]) |
| 60 | + e = np.random.randn(*dims[1:]) |
| 61 | + for _ in range(self.max_iter): |
| 62 | + d = self.finite_diff * self._normalize(d) |
| 63 | + e = self.finite_diff * self._normalize(e) |
| 64 | + preds_val_d = self.sess.run(self._preds, {self._x: [val + d]})[0] |
| 65 | + preds_val_e = self.sess.run(self._preds, {self._x: [val + e]})[0] |
| 66 | + |
| 67 | + # Compute KL divergence between logits |
| 68 | + from scipy.stats import entropy |
| 69 | + kl_div1 = entropy(preds_val[ind], preds_val_d) |
| 70 | + kl_div2 = entropy(preds_val[ind], preds_val_e) |
| 71 | + d = (kl_div1 - kl_div2) / np.abs(d - e) |
| 72 | + |
| 73 | + # Apply perturbation and clip |
| 74 | + val += self.eps * self._normalize(d) |
| 75 | + if self.clip_min is not None or self.clip_max is not None: |
| 76 | + val = np.clip(val, self.clip_min, self.clip_max) |
50 | 77 |
|
51 | | - def generate(self, x_val, eps=0.1, **kwargs): |
| 78 | + return x_adv |
| 79 | + |
| 80 | + def _normalize(self, x): |
52 | 81 | """ |
53 | | - Generate adversarial samples and return them in a Numpy array. |
54 | | - :param x_val: (required) A Numpy array with the original inputs. |
55 | | - :param eps: (optional float) the epsilon (max input variation parameter) |
56 | | - :param max_iter: (optinal integer) The maximum number of iterations. |
57 | | - :param xi: (optional float) The finite difference parameter |
58 | | - :param clip_min: (optional float) Minimum input component value |
59 | | - :param clip_max: (optional float) Maximum input component value |
| 82 | + Apply L_2 batch normalization on `x`. |
| 83 | +
|
| 84 | + :param x: (np.ndarray) The input array to normalize |
| 85 | + :return: The normalized version of `x` |
| 86 | + :rtype: np.ndarray |
60 | 87 | """ |
61 | | - # Generate this attack's graph if it hasn't been done previously |
62 | | - input_shape = list(x_val.shape) |
63 | | - input_shape[0] = None |
64 | | - self._x = tf.placeholder(tf.float32, shape=input_shape) |
65 | | - self._x_adv = self.generate_graph(self._x, eps, **kwargs) |
| 88 | + tol = 1e-12 |
| 89 | + dims = x.shape |
66 | 90 |
|
67 | | - return self.sess.run(self._x_adv, feed_dict={self._x: x_val, k.learning_phase(): 0}) |
| 91 | + x = x.flatten() |
| 92 | + x /= np.max(np.abs(x)) + tol |
| 93 | + inverse = (np.sum(x**2) + np.sqrt(tol)) ** -.5 |
| 94 | + x = x * inverse |
| 95 | + x = np.reshape(x, dims) |
| 96 | + |
| 97 | + return x |
68 | 98 |
|
69 | 99 | def set_params(self, **kwargs): |
70 | 100 | """ |
71 | 101 | Take in a dictionary of parameters and applies attack-specific checks before saving them as attributes. |
72 | 102 |
|
73 | 103 | Attack-specific parameters: |
74 | | - :param max_iter: (optional integer) The maximum number of iterations. |
75 | | - :param xi: (optional float) The finite difference parameter |
| 104 | + :param eps: (optional float) the epsilon (max input variation parameter) |
| 105 | + :param finite_diff: (optional float) The finite difference parameter |
| 106 | + :param max_iter: (optional integer) The maximum number of iterations |
76 | 107 | :param clip_min: (optional float) Minimum input component value |
77 | 108 | :param clip_max: (optional float) Maximum input component value |
78 | 109 | """ |
|
0 commit comments