
Commit 40f435c

Irina Nicolae authored and committed
Merge branch 'dev'
Changes:
- new defenses: JPEG compression, total variance minimization
- optimization of NewtonFool attack to run batches
- small optimizations / bug fix in Carlini attack
- small optimizations in DeepFool
- changes to compute class gradients for a list of targets for all classifiers
- minor changes in visualization module and utils
- updates in docs, examples and notebooks

Conflicts:
    README.md
    art/utils.py
    notebooks/mnist_poisoning_demo.ipynb
2 parents 484cadd + d1966db commit 40f435c
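
For context on the new JPEG compression defence listed in the changes above, the following is a minimal standalone sketch of the idea: each input is re-encoded as a lossy JPEG and decoded again, which tends to wash out high-frequency adversarial perturbations. This is not the ART API; the function name `jpeg_compress`, the PIL-based round trip, and the assumed `(N, H, W, C)` input layout in `[0, 1]` are illustrative assumptions.

```python
from io import BytesIO

import numpy as np
from PIL import Image


def jpeg_compress(x, quality=75):
    """Hypothetical sketch: JPEG-encode/decode a batch of images in [0, 1].

    `x` is assumed to have shape (N, H, W, C) with C in {1, 3}.
    """
    x_out = np.zeros_like(x)
    for i, sample in enumerate(x):
        # PIL expects 8-bit pixels; squeeze a trailing channel of size 1
        img = Image.fromarray(np.uint8(np.squeeze(sample) * 255))
        buffer = BytesIO()
        img.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        decoded = np.asarray(Image.open(buffer), dtype=np.float32) / 255.0
        x_out[i] = decoded.reshape(sample.shape)
    return x_out
```

Like the other preprocessing defences, a transform of this kind is applied to inputs before they reach the classifier.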


41 files changed: 1661 additions, 481 deletions

.travis.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -10,12 +10,16 @@ matrix:
       env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.6.0
     - python: 2.7
       env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.7.0
+    - python: 2.7
+      env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.10.0
     - python: 3.5
       env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.5.0
     - python: 3.5
       env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.6.0
     - python: 3.5
       env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.7.0
+    - python: 3.5
+      env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.10.0
   exclude:
     - env:
```

README.md

Lines changed: 3 additions & 1 deletion
```diff
@@ -26,13 +26,15 @@ The following **defence** methods are also supported:
 * Virtual adversarial training ([Miyato et al., 2015](https://arxiv.org/abs/1507.00677))
 * Gaussian data augmentation ([Zantedeschi et al., 2017](https://arxiv.org/abs/1707.06728))
 * Thermometer encoding ([Buckman et al., 2018](https://openreview.net/forum?id=S18Su--CW))
+* Total variance minimization ([Guo et al., 2018](https://openreview.net/forum?id=SyJ7ClWCb))
+* JPEG compression ([Dziugaite et al., 2016](https://arxiv.org/abs/1608.00853))
 
 ART also implements **detection** methods of adversarial samples:
 * Basic detector based on inputs
 * Detector trained on the activations of a specific layer
 
 The following **detector of poisoning attacks** is also supported:
-* Detector based on activations analysis
+* Detector based on activations analysis ([Chen et al., 2018](https://arxiv.org/abs/1811.03728))
 
 ## Setup
 
```
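
Total variance minimization, the other defence added to the README above, replaces each input with a nearby image of small total variation. A crude (sub)gradient-descent sketch of the objective ||z − x||² + λ·TV(z) for a single 2-D image is shown below; the names `tv_minimize`, `lam`, `lr` and the fixed step count are illustrative assumptions, not the ART implementation (Guo et al., 2018 additionally reconstruct from a random subset of pixels, which this sketch omits).

```python
import numpy as np


def tv_minimize(x, lam=0.3, lr=0.1, steps=100):
    """Hypothetical sketch: (sub)gradient descent on ||z - x||^2 + lam * TV(z).

    `x` is a single 2-D float image; returns the smoothed reconstruction.
    """
    z = x.astype(np.float64)
    for _ in range(steps):
        # Subgradient of the anisotropic total variation term
        dv = np.sign(np.diff(z, axis=0))   # vertical neighbour differences
        dh = np.sign(np.diff(z, axis=1))   # horizontal neighbour differences
        grad_tv = np.zeros_like(z)
        grad_tv[:-1, :] -= dv
        grad_tv[1:, :] += dv
        grad_tv[:, :-1] -= dh
        grad_tv[:, 1:] += dh

        # Gradient of the reconstruction term plus the TV subgradient
        grad = 2.0 * (z - x) + lam * grad_tv
        z -= lr * grad
    return z
```

The reconstruction is then fed to the classifier in place of the raw input, with λ and the step size tuned per dataset.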

art/attacks/carlini.py

Lines changed: 75 additions & 37 deletions
```diff
@@ -5,6 +5,7 @@
 
 import numpy as np
 
+from art import NUMPY_DTYPE
 from art.attacks.attack import Attack
 from art.utils import get_labels_np_array
 
@@ -13,10 +14,10 @@
 
 class CarliniL2Method(Attack):
     """
-    The L_2 optimized attack of Carlini and Wagner (2016). This attack is the most efficient and should be used as the
-    primary attack to evaluate potential defences (wrt the L_0 and L_inf attacks). This implementation is inspired by
-    the one in Cleverhans, which reproduces the authors' original code (https://github.com/carlini/nn_robust_attacks).
-    Paper link: https://arxiv.org/pdf/1608.04644.pdf
+    The L_2 optimized attack of Carlini and Wagner (2016). This attack is among the most effective and should be used
+    among the primary attacks to evaluate potential defences. A major difference wrt to the original implementation
+    (https://github.com/carlini/nn_robust_attacks) is that we use line search in the optimization of the attack
+    objective. Paper link: https://arxiv.org/pdf/1608.04644.pdf
     """
     attack_params = Attack.attack_params + ['confidence', 'targeted', 'learning_rate', 'max_iter',
                                             'binary_search_steps', 'initial_const', 'max_halving', 'max_doubling']
@@ -83,10 +84,10 @@ def _loss(self, x, x_adv, target, c):
         :return: A tuple holding the current logits, l2 distance and overall loss.
         :rtype: `(float, float, float)`
         """
-        l2dist = np.sum(np.square(x-x_adv))
-        z = self.classifier.predict(np.array([x_adv]), logits=True)[0]
+        l2dist = np.sum(np.square(x - x_adv))
+        z = self.classifier.predict(np.array([x_adv], dtype=NUMPY_DTYPE), logits=True)[0]
         z_target = np.sum(z * target)
-        z_other = np.max(z * (1 - target) + (np.min(z)-1)*target)
+        z_other = np.max(z * (1 - target) + (np.min(z) - 1) * target)
 
         # The following differs from the exact definition given in Carlini and Wagner (2016). There (page 9, left
         # column, last equation), the maximum is taken over Z_other - Z_target (or Z_target - Z_other respectively)
@@ -127,20 +128,21 @@ def _loss_gradient(self, z, target, x, x_adv, x_adv_tanh, c, clip_min, clip_max)
         :type target: `np.ndarray`
         """
         if self.targeted:
-            i_sub, i_add = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z)-1)*target)
+            i_sub, i_add = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z) - 1) * target)
         else:
-            i_add, i_sub = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z)-1)*target)
+            i_add, i_sub = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z) - 1) * target)
 
-        loss_gradient = self.classifier.class_gradient(np.array([x_adv]), label=i_add, logits=True)[0]
-        loss_gradient -= self.classifier.class_gradient(np.array([x_adv]), label=i_sub, logits=True)[0]
+        loss_gradient = self.classifier.class_gradient(np.array([x_adv], dtype=NUMPY_DTYPE), label=i_add,
+                                                       logits=True)[0]
+        loss_gradient -= self.classifier.class_gradient(np.array([x_adv], dtype=NUMPY_DTYPE), label=i_sub,
+                                                        logits=True)[0]
         loss_gradient *= c
-        loss_gradient += 2*(x_adv - x)
+        loss_gradient += 2 * (x_adv - x)
         loss_gradient *= (clip_max - clip_min)
-        loss_gradient *= (1-np.square(np.tanh(x_adv_tanh)))/(2*self._tanh_smoother)
+        loss_gradient *= (1 - np.square(np.tanh(x_adv_tanh))) / (2 * self._tanh_smoother)
 
         return loss_gradient[0]
-
-
+
     def _original_to_tanh(self, x_original, clip_min, clip_max):
         """
         Transform input from original to tanh space.
@@ -191,7 +193,7 @@ def generate(self, x, **kwargs):
         :return: An array holding the adversarial examples.
         :rtype: `np.ndarray`
         """
-        x_adv = x.copy()
+        x_adv = x.astype(NUMPY_DTYPE)
         (clip_min, clip_max) = self.classifier.clip_values
 
         # Parse and save attack-specific parameters
@@ -207,7 +209,8 @@ def generate(self, x, **kwargs):
         if y is None:
             y = get_labels_np_array(self.classifier.predict(x, logits=False))
 
-        for j, (ex, target) in enumerate(zip(x_adv, y)):
+        for j, (ex, target) in enumerate(zip(x_adv, y)):
+            logger.debug('Processing sample %i out of %i', j, x_adv.shape[0])
             image = ex.copy()
 
             # The optimization is performed in tanh space to keep the
@@ -221,63 +224,98 @@ def generate(self, x, **kwargs):
 
             # Initialize placeholders for best l2 distance and attack found so far
             best_l2dist = sys.float_info.max
-            best_adv_image = image
-            lr = self.learning_rate
+            best_adv_image = image
 
-            for _ in range(self.binary_search_steps):
+            for bss in range(self.binary_search_steps):
+                lr = self.learning_rate
+                logger.debug('Binary search step %i out of %i (c==%f)', bss, self.binary_search_steps, c)
 
                 # Initialize perturbation in tanh space:
-                perturbation_tanh = np.zeros(image_tanh.shape)
                 adv_image = image
                 adv_image_tanh = image_tanh
                 z, l2dist, loss = self._loss(image, adv_image, target, c)
-                attack_success = (loss-l2dist <= 0)
+                attack_success = (loss - l2dist <= 0)
+                overall_attack_success = attack_success
 
-                for it in range(self.max_iter):
+                for it in range(self.max_iter):
+                    logger.debug('Iteration step %i out of %i', it, self.max_iter)
+                    logger.debug('Total Loss: %f', loss)
+                    logger.debug('L2Dist: %f', l2dist)
+                    logger.debug('Margin Loss: %f', loss-l2dist)
+
                     if attack_success:
-                        break
+                        logger.debug('Margin Loss <= 0 --> Attack Success!')
+                        if l2dist < best_l2dist:
+                            logger.debug('New best L2Dist: %f (previous=%f)', l2dist, best_l2dist)
+                            best_l2dist = l2dist
+                            best_adv_image = adv_image
 
                     # compute gradient:
+                    logger.debug('Compute loss gradient')
                     perturbation_tanh = -self._loss_gradient(z, target, image, adv_image, adv_image_tanh,
                                                              c, clip_min, clip_max)
 
                     # perform line search to optimize perturbation
                     # first, halve the learning rate until perturbation actually decreases the loss:
                     prev_loss = loss
+                    best_loss = loss
+                    best_lr = 0
+
                     halving = 0
-                    while loss >= prev_loss and loss-l2dist > 0 and halving < self.max_halving:
-                        new_adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
+                    while loss >= prev_loss and halving < self.max_halving:
+                        logger.debug('Apply gradient with learning rate %f (halving=%i)', lr, halving)
+                        new_adv_image_tanh = adv_image_tanh + lr * perturbation_tanh
                         new_adv_image = self._tanh_to_original(new_adv_image_tanh, clip_min, clip_max)
-                        _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                        _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                        logger.debug('New Total Loss: %f', loss)
+                        logger.debug('New L2Dist: %f', l2dist)
+                        logger.debug('New Margin Loss: %f', loss-l2dist)
+                        if loss < best_loss:
+                            best_loss = loss
+                            best_lr = lr
                         lr /= 2
                         halving += 1
                     lr *= 2
 
                     # if no halving was actually required, double the learning rate as long as this
                     # decreases the loss:
-                    if halving == 1:
+                    if halving == 1 and loss <= prev_loss:
                         doubling = 0
                         while loss <= prev_loss and doubling < self.max_doubling:
                             prev_loss = loss
                             lr *= 2
+                            logger.debug('Apply gradient with learning rate %f (doubling=%i)', lr, doubling)
                             doubling += 1
-                            new_adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
+                            new_adv_image_tanh = adv_image_tanh + lr * perturbation_tanh
                             new_adv_image = self._tanh_to_original(new_adv_image_tanh, clip_min, clip_max)
-                            _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                            _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                            logger.debug('New Total Loss: %f', loss)
+                            logger.debug('New L2Dist: %f', l2dist)
+                            logger.debug('New Margin Loss: %f', loss-l2dist)
+                            if loss < best_loss:
+                                best_loss = loss
+                                best_lr = lr
                         lr /= 2
 
-                    # apply the optimal learning rate that was found and update the loss:
-                    adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
-                    adv_image = self._tanh_to_original(adv_image_tanh, clip_min, clip_max)
+                    if best_lr >0:
+                        logger.debug('Finally apply gradient with learning rate %f', best_lr)
+                        # apply the optimal learning rate that was found and update the loss:
+                        adv_image_tanh = adv_image_tanh + best_lr * perturbation_tanh
+                        adv_image = self._tanh_to_original(adv_image_tanh, clip_min, clip_max)
+
                     z, l2dist, loss = self._loss(image, adv_image, target, c)
-                    attack_success = (loss-l2dist <= 0)
+                    attack_success = (loss - l2dist <= 0)
+                    overall_attack_success = overall_attack_success or attack_success
 
                 # Update depending on attack success:
                 if attack_success:
+                    logger.debug('Margin Loss <= 0 --> Attack Success!')
                     if l2dist < best_l2dist:
+                        logger.debug('New best L2Dist: %f (previous=%f)', l2dist, best_l2dist)
                         best_l2dist = l2dist
-                        best_adv_image = adv_image
-
+                        best_adv_image = adv_image
+
+                if overall_attack_success:
                     c_double = False
                     c = (c_lower_bound + c) / 2
                 else:
@@ -300,7 +338,7 @@ def generate(self, x, **kwargs):
         else:
             preds = np.argmax(self.classifier.predict(x), axis=1)
         rate = np.sum(adv_preds != preds) / x_adv.shape[0]
-        logger.info('Success rate of C&W attack: %.2f%%', rate)
+        logger.info('Success rate of C&W attack: %.2f%%', 100*rate)
 
         return x_adv
```
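
The docstring change above advertises the main behavioural change of this commit: instead of applying a fixed gradient step, each iteration now runs a halving/doubling line search over the learning rate and applies only the best step found. Stripped of the attack-specific state, the schedule looks roughly like the standalone sketch below; `loss_fn`, `theta`, `direction` and the default bounds are illustrative assumptions, not the module's API.

```python
import numpy as np


def line_search_step(loss_fn, theta, direction, lr0, max_halving=5, max_doubling=5):
    """Hypothetical sketch of the halving/doubling line search used per iteration.

    Returns the best step size found along `direction` (0.0 if no step helped).
    """
    prev_loss = loss_fn(theta)
    best_loss, best_lr = prev_loss, 0.0

    # First, halve the learning rate until the step actually decreases the loss
    lr, halving, loss = lr0, 0, prev_loss
    while loss >= prev_loss and halving < max_halving:
        loss = loss_fn(theta + lr * direction)
        if loss < best_loss:
            best_loss, best_lr = loss, lr
        lr /= 2
        halving += 1
    lr *= 2

    # If the very first step already helped, keep doubling while the loss drops
    if halving == 1 and loss <= prev_loss:
        doubling = 0
        while loss <= prev_loss and doubling < max_doubling:
            prev_loss = loss
            lr *= 2
            doubling += 1
            loss = loss_fn(theta + lr * direction)
            if loss < best_loss:
                best_loss, best_lr = loss, lr

    return best_lr
```

In the attack, the returned step is applied to the perturbation in tanh space, and the update is skipped when no step size reduced the loss (the `if best_lr >0:` branch in the diff).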

art/attacks/deepfool.py

Lines changed: 12 additions & 14 deletions
```diff
@@ -47,31 +47,26 @@ def generate(self, x, **kwargs):
         assert self.set_params(**kwargs)
         clip_min, clip_max = self.classifier.clip_values
         x_adv = x.copy()
+        preds = self.classifier.predict(x, logits=True)
 
         # Pick a small scalar to avoid division by 0
         tol = 10e-8
 
         for j, val in enumerate(x_adv):
             xj = val[None, ...]
-
-            # TODO move prediction outside of for loop; add batching if `x` is too large?
-            f = self.classifier.predict(xj, logits=True)[0]
+            f = preds[j]
             grd = self.classifier.class_gradient(xj, logits=True)[0]
             fk_hat = np.argmax(f)
-            fk_i_hat = fk_hat
-            nb_iter = 0
 
-            while fk_i_hat == fk_hat and nb_iter < self.max_iter:
+            for _ in range(self.max_iter):
                 grad_diff = grd - grd[fk_hat]
                 f_diff = f - f[fk_hat]
 
-                # Masking true label
-                mask = [0] * self.classifier.nb_classes
-                mask[fk_hat] = 1
+                # Choose coordinate and compute perturbation
                 norm = np.linalg.norm(grad_diff.reshape(self.classifier.nb_classes, -1), axis=1) + tol
-                value = np.ma.array(np.abs(f_diff) / norm, mask=mask)
-
-                l = value.argmin(fill_value=np.inf)
+                value = np.abs(f_diff) / norm
+                value[fk_hat] = np.inf
+                l = np.argmin(value)
                 r = (abs(f_diff[l]) / (pow(np.linalg.norm(grad_diff[l]), 2) + tol)) * grad_diff[l]
 
                 # Add perturbation and clip result
@@ -82,11 +77,14 @@
                 grd = self.classifier.class_gradient(xj, logits=True)[0]
                 fk_i_hat = np.argmax(f)
 
-                nb_iter += 1
+                # Stop if misclassification has been achieved
+                if fk_i_hat != fk_hat:
+                    break
 
+            # Apply overshoot parameter
             x_adv[j] = np.clip(x[j] + (1 + self.epsilon) * (xj[0] - x[j]), clip_min, clip_max)
 
-        preds = np.argmax(self.classifier.predict(x), axis=1)
+        preds = np.argmax(preds, axis=1)
         preds_adv = np.argmax(self.classifier.predict(x_adv), axis=1)
         logger.info('Success rate of DeepFool attack: %.2f%%', (np.sum(preds != preds_adv) / x.shape[0]))
 
```
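
For readability, the per-sample update that the rewritten DeepFool loop performs can be read as the standalone sketch below: linearize the logits, pick the class whose decision boundary is closest under that linearization, and take the minimal step that crosses it. The helper name `deepfool_step` and its argument layout are illustrative, not part of the module.

```python
import numpy as np


def deepfool_step(f, grads, label, tol=10e-8):
    """Hypothetical sketch of one DeepFool update for a single sample.

    f:     logits, shape (nb_classes,)
    grads: per-class logit gradients, shape (nb_classes,) + input_shape
    label: index of the currently predicted class
    """
    grad_diff = grads - grads[label]                # gradients of f_k - f_label
    f_diff = f - f[label]                           # logit margins to each class
    norm = np.linalg.norm(grad_diff.reshape(len(f), -1), axis=1) + tol

    # Linearized distance to each class boundary; never pick the current class
    value = np.abs(f_diff) / norm
    value[label] = np.inf
    l = np.argmin(value)

    # Minimal perturbation that reaches the chosen boundary
    return (abs(f_diff[l]) / (np.linalg.norm(grad_diff[l]) ** 2 + tol)) * grad_diff[l]
```

The loop adds this perturbation to the sample, recomputes the prediction and gradients, and stops early once the predicted class changes; the final overshoot by (1 + epsilon) is applied outside the loop, as in the diff.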
