
Commit ee3b7ac

Irina Nicolae authored and committed
Merge branch 'master' into release
Changes:
- new defenses: JPEG compression, total variance minimization
- optimization of NewtonFool attack to run batches
- small optimizations / bug fix in Carlini attack
- small optimizations in DeepFool
- changes to compute class gradients for a list of targets for all classifiers
- minor changes in visualization module and utils
- updates in docs, examples and notebooks

Conflicts:
	art/defences/__init__.py
	art/utils.py
	requirements.txt
2 parents debd3d3 + 40f435c commit ee3b7ac


42 files changed (+1777 / -478 lines)

.travis.yml

Lines changed: 4 additions & 0 deletions
@@ -9,12 +9,16 @@ matrix:
     env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.6.0
   - python: 2.7
     env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.7.0
+  - python: 2.7
+    env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.10.0
   - python: 3.5
     env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.5.0
   - python: 3.5
     env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.6.0
   - python: 3.5
     env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.7.0
+  - python: 3.5
+    env: KERAS_BACKEND=tensorflow TENSORFLOW_V=1.10.0
   exclude:
   - env:

README.md

Lines changed: 3 additions & 1 deletion
@@ -26,13 +26,15 @@ The following **defence** methods are also supported:
 * Virtual adversarial training ([Miyato et al., 2015](https://arxiv.org/abs/1507.00677))
 * Gaussian data augmentation ([Zantedeschi et al., 2017](https://arxiv.org/abs/1707.06728))
 * Thermometer encoding ([Buckman et al., 2018](https://openreview.net/forum?id=S18Su--CW))
+* Total variance minimization ([Guo et al., 2018](https://openreview.net/forum?id=SyJ7ClWCb))
+* JPEG compression ([Dziugaite et al., 2016](https://arxiv.org/abs/1608.00853))
 
 ART also implements **detection** methods of adversarial samples:
 * Basic detector based on inputs
 * Detector trained on the activations of a specific layer
 
 The following **detector of poisoning attacks** is also supported:
-* Detector based on activations analysis
+* Detector based on activations analysis ([Chen et al., 2018](https://arxiv.org/abs/1811.03728))
 
 ## Setup
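Both defences added to this list are preprocessing steps applied to inputs before they reach the classifier. As a rough, hedged sketch (not the ART API; the function name and quality parameter are illustrative only), JPEG compression can be approximated by round-tripping each image through a JPEG encoder:

# Illustrative sketch of JPEG compression as a preprocessing defence (not the ART API).
# Assumes float images in [0, 1] with shape (N, H, W, C); requires Pillow.
from io import BytesIO

import numpy as np
from PIL import Image


def jpeg_compress(x, quality=75):
    """Round-trip each image through a JPEG encoder to suppress high-frequency perturbations."""
    out = np.zeros_like(x)
    for i, img in enumerate(x):
        pil_img = Image.fromarray((img * 255).astype(np.uint8).squeeze())
        buf = BytesIO()
        pil_img.save(buf, format='JPEG', quality=quality)
        buf.seek(0)
        decoded = np.asarray(Image.open(buf), dtype=np.float32) / 255.0
        out[i] = decoded.reshape(img.shape)
    return out

Total variance minimization plugs in the same way, but reconstructs each image by solving a small optimization problem instead of re-encoding it.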

art/attacks/carlini.py

Lines changed: 75 additions & 37 deletions
@@ -22,6 +22,7 @@
 
 import numpy as np
 
+from art import NUMPY_DTYPE
 from art.attacks.attack import Attack
 from art.utils import get_labels_np_array
 
@@ -30,10 +31,10 @@
 
 class CarliniL2Method(Attack):
     """
-    The L_2 optimized attack of Carlini and Wagner (2016). This attack is the most efficient and should be used as the
-    primary attack to evaluate potential defences (wrt the L_0 and L_inf attacks). This implementation is inspired by
-    the one in Cleverhans, which reproduces the authors' original code (https://github.com/carlini/nn_robust_attacks).
-    Paper link: https://arxiv.org/pdf/1608.04644.pdf
+    The L_2 optimized attack of Carlini and Wagner (2016). This attack is among the most effective and should be used
+    among the primary attacks to evaluate potential defences. A major difference wrt to the original implementation
+    (https://github.com/carlini/nn_robust_attacks) is that we use line search in the optimization of the attack
+    objective. Paper link: https://arxiv.org/pdf/1608.04644.pdf
     """
     attack_params = Attack.attack_params + ['confidence', 'targeted', 'learning_rate', 'max_iter',
                                             'binary_search_steps', 'initial_const', 'max_halving', 'max_doubling']
@@ -100,10 +101,10 @@ def _loss(self, x, x_adv, target, c):
         :return: A tuple holding the current logits, l2 distance and overall loss.
         :rtype: `(float, float, float)`
         """
-        l2dist = np.sum(np.square(x-x_adv))
-        z = self.classifier.predict(np.array([x_adv]), logits=True)[0]
+        l2dist = np.sum(np.square(x - x_adv))
+        z = self.classifier.predict(np.array([x_adv], dtype=NUMPY_DTYPE), logits=True)[0]
         z_target = np.sum(z * target)
-        z_other = np.max(z * (1 - target) + (np.min(z)-1)*target)
+        z_other = np.max(z * (1 - target) + (np.min(z) - 1) * target)
 
         # The following differs from the exact definition given in Carlini and Wagner (2016). There (page 9, left
         # column, last equation), the maximum is taken over Z_other - Z_target (or Z_target - Z_other respectively)
@@ -144,20 +145,21 @@ def _loss_gradient(self, z, target, x, x_adv, x_adv_tanh, c, clip_min, clip_max)
         :type target: `np.ndarray`
         """
         if self.targeted:
-            i_sub, i_add = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z)-1)*target)
+            i_sub, i_add = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z) - 1) * target)
         else:
-            i_add, i_sub = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z)-1)*target)
+            i_add, i_sub = np.argmax(target), np.argmax(z * (1 - target) + (np.min(z) - 1) * target)
 
-        loss_gradient = self.classifier.class_gradient(np.array([x_adv]), label=i_add, logits=True)[0]
-        loss_gradient -= self.classifier.class_gradient(np.array([x_adv]), label=i_sub, logits=True)[0]
+        loss_gradient = self.classifier.class_gradient(np.array([x_adv], dtype=NUMPY_DTYPE), label=i_add,
+                                                       logits=True)[0]
+        loss_gradient -= self.classifier.class_gradient(np.array([x_adv], dtype=NUMPY_DTYPE), label=i_sub,
+                                                        logits=True)[0]
         loss_gradient *= c
-        loss_gradient += 2*(x_adv - x)
+        loss_gradient += 2 * (x_adv - x)
         loss_gradient *= (clip_max - clip_min)
-        loss_gradient *= (1-np.square(np.tanh(x_adv_tanh)))/(2*self._tanh_smoother)
+        loss_gradient *= (1 - np.square(np.tanh(x_adv_tanh))) / (2 * self._tanh_smoother)
 
         return loss_gradient[0]
-
-
+
     def _original_to_tanh(self, x_original, clip_min, clip_max):
         """
         Transform input from original to tanh space.
@@ -208,7 +210,7 @@ def generate(self, x, **kwargs):
         :return: An array holding the adversarial examples.
         :rtype: `np.ndarray`
         """
-        x_adv = x.copy()
+        x_adv = x.astype(NUMPY_DTYPE)
         (clip_min, clip_max) = self.classifier.clip_values
 
         # Parse and save attack-specific parameters
@@ -224,7 +226,8 @@ def generate(self, x, **kwargs):
         if y is None:
             y = get_labels_np_array(self.classifier.predict(x, logits=False))
 
-        for j, (ex, target) in enumerate(zip(x_adv, y)):
+        for j, (ex, target) in enumerate(zip(x_adv, y)):
+            logger.debug('Processing sample %i out of %i', j, x_adv.shape[0])
             image = ex.copy()
 
             # The optimization is performed in tanh space to keep the
@@ -238,63 +241,98 @@ def generate(self, x, **kwargs):
 
             # Initialize placeholders for best l2 distance and attack found so far
             best_l2dist = sys.float_info.max
-            best_adv_image = image
-            lr = self.learning_rate
+            best_adv_image = image
 
-            for _ in range(self.binary_search_steps):
+            for bss in range(self.binary_search_steps):
+                lr = self.learning_rate
+                logger.debug('Binary search step %i out of %i (c==%f)', bss, self.binary_search_steps, c)
 
                 # Initialize perturbation in tanh space:
-                perturbation_tanh = np.zeros(image_tanh.shape)
                 adv_image = image
                 adv_image_tanh = image_tanh
                 z, l2dist, loss = self._loss(image, adv_image, target, c)
-                attack_success = (loss-l2dist <= 0)
+                attack_success = (loss - l2dist <= 0)
+                overall_attack_success = attack_success
 
-                for it in range(self.max_iter):
+                for it in range(self.max_iter):
+                    logger.debug('Iteration step %i out of %i', it, self.max_iter)
+                    logger.debug('Total Loss: %f', loss)
+                    logger.debug('L2Dist: %f', l2dist)
+                    logger.debug('Margin Loss: %f', loss-l2dist)
+
                     if attack_success:
-                        break
+                        logger.debug('Margin Loss <= 0 --> Attack Success!')
+                        if l2dist < best_l2dist:
+                            logger.debug('New best L2Dist: %f (previous=%f)', l2dist, best_l2dist)
+                            best_l2dist = l2dist
+                            best_adv_image = adv_image
 
                     # compute gradient:
+                    logger.debug('Compute loss gradient')
                    perturbation_tanh = -self._loss_gradient(z, target, image, adv_image, adv_image_tanh,
                                                              c, clip_min, clip_max)
 
                     # perform line search to optimize perturbation
                     # first, halve the learning rate until perturbation actually decreases the loss:
                     prev_loss = loss
+                    best_loss = loss
+                    best_lr = 0
+
                     halving = 0
-                    while loss >= prev_loss and loss-l2dist > 0 and halving < self.max_halving:
-                        new_adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
+                    while loss >= prev_loss and halving < self.max_halving:
+                        logger.debug('Apply gradient with learning rate %f (halving=%i)', lr, halving)
+                        new_adv_image_tanh = adv_image_tanh + lr * perturbation_tanh
                         new_adv_image = self._tanh_to_original(new_adv_image_tanh, clip_min, clip_max)
-                        _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                        _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                        logger.debug('New Total Loss: %f', loss)
+                        logger.debug('New L2Dist: %f', l2dist)
+                        logger.debug('New Margin Loss: %f', loss-l2dist)
+                        if loss < best_loss:
+                            best_loss = loss
+                            best_lr = lr
                         lr /= 2
                         halving += 1
                     lr *= 2
 
                     # if no halving was actually required, double the learning rate as long as this
                     # decreases the loss:
-                    if halving == 1:
+                    if halving == 1 and loss <= prev_loss:
                         doubling = 0
                         while loss <= prev_loss and doubling < self.max_doubling:
                             prev_loss = loss
                             lr *= 2
+                            logger.debug('Apply gradient with learning rate %f (doubling=%i)', lr, doubling)
                             doubling += 1
-                            new_adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
+                            new_adv_image_tanh = adv_image_tanh + lr * perturbation_tanh
                             new_adv_image = self._tanh_to_original(new_adv_image_tanh, clip_min, clip_max)
-                            _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                            _, l2dist, loss = self._loss(image, new_adv_image, target, c)
+                            logger.debug('New Total Loss: %f', loss)
+                            logger.debug('New L2Dist: %f', l2dist)
+                            logger.debug('New Margin Loss: %f', loss-l2dist)
+                            if loss < best_loss:
+                                best_loss = loss
+                                best_lr = lr
                         lr /= 2
 
-                    # apply the optimal learning rate that was found and update the loss:
-                    adv_image_tanh = adv_image_tanh + lr*perturbation_tanh
-                    adv_image = self._tanh_to_original(adv_image_tanh, clip_min, clip_max)
+                    if best_lr >0:
+                        logger.debug('Finally apply gradient with learning rate %f', best_lr)
+                        # apply the optimal learning rate that was found and update the loss:
+                        adv_image_tanh = adv_image_tanh + best_lr * perturbation_tanh
+                        adv_image = self._tanh_to_original(adv_image_tanh, clip_min, clip_max)
+
                     z, l2dist, loss = self._loss(image, adv_image, target, c)
-                    attack_success = (loss-l2dist <= 0)
+                    attack_success = (loss - l2dist <= 0)
+                    overall_attack_success = overall_attack_success or attack_success
 
                 # Update depending on attack success:
                 if attack_success:
+                    logger.debug('Margin Loss <= 0 --> Attack Success!')
                    if l2dist < best_l2dist:
+                        logger.debug('New best L2Dist: %f (previous=%f)', l2dist, best_l2dist)
                         best_l2dist = l2dist
-                        best_adv_image = adv_image
-
+                        best_adv_image = adv_image
+
+                if overall_attack_success:
                     c_double = False
                     c = (c_lower_bound + c) / 2
                 else:
@@ -317,7 +355,7 @@ def generate(self, x, **kwargs):
         else:
             preds = np.argmax(self.classifier.predict(x), axis=1)
         rate = np.sum(adv_preds != preds) / x_adv.shape[0]
-        logger.info('Success rate of C&W attack: %.2f%%', rate)
+        logger.info('Success rate of C&W attack: %.2f%%', 100*rate)
 
         return x_adv
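The updated docstring and the loop above centre on a halving/doubling line search along the loss gradient rather than a fixed-step optimizer. Below is a minimal, self-contained sketch of that search pattern; loss_fn, x, and gradient are generic stand-ins introduced for illustration, not the attack's internal names.

# Minimal sketch of the halving/doubling line search, assuming a generic scalar
# objective `loss_fn(x)` (illustrative only, not the module's exact code).
import numpy as np


def line_search_step(x, gradient, loss_fn, lr=0.1, max_halving=5, max_doubling=5):
    """Try shrinking, then growing, the step size and keep the best loss found."""
    prev_loss = loss_fn(x)
    best_loss, best_lr = prev_loss, 0.0

    # Halve the learning rate until the step actually decreases the loss.
    loss, halving = prev_loss, 0
    while loss >= prev_loss and halving < max_halving:
        loss = loss_fn(x - lr * gradient)
        if loss < best_loss:
            best_loss, best_lr = loss, lr
        lr /= 2
        halving += 1
    lr *= 2

    # If the very first step already helped, double the rate while it keeps helping.
    if halving == 1 and loss <= prev_loss:
        doubling = 0
        while loss <= prev_loss and doubling < max_doubling:
            prev_loss = loss
            lr *= 2
            doubling += 1
            loss = loss_fn(x - lr * gradient)
            if loss < best_loss:
                best_loss, best_lr = loss, lr
        lr /= 2

    # Commit only the best step found; a failed search leaves x unchanged.
    return (x - best_lr * gradient) if best_lr > 0 else x


# Toy usage: one line-search step minimizing a quadratic.
x0 = np.array([3.0, -2.0])
x1 = line_search_step(x0, gradient=2 * x0, loss_fn=lambda v: float(np.sum(v ** 2)))

The attack in the diff applies the same pattern per iteration and only updates the adversarial candidate when best_lr > 0, so an unsuccessful search leaves the current candidate untouched.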

art/attacks/deepfool.py

Lines changed: 12 additions & 14 deletions
@@ -64,31 +64,26 @@ def generate(self, x, **kwargs):
         assert self.set_params(**kwargs)
         clip_min, clip_max = self.classifier.clip_values
         x_adv = x.copy()
+        preds = self.classifier.predict(x, logits=True)
 
         # Pick a small scalar to avoid division by 0
         tol = 10e-8
 
         for j, val in enumerate(x_adv):
             xj = val[None, ...]
-
-            # TODO move prediction outside of for loop; add batching if `x` is too large?
-            f = self.classifier.predict(xj, logits=True)[0]
+            f = preds[j]
             grd = self.classifier.class_gradient(xj, logits=True)[0]
             fk_hat = np.argmax(f)
-            fk_i_hat = fk_hat
-            nb_iter = 0
 
-            while fk_i_hat == fk_hat and nb_iter < self.max_iter:
+            for _ in range(self.max_iter):
                 grad_diff = grd - grd[fk_hat]
                 f_diff = f - f[fk_hat]
 
-                # Masking true label
-                mask = [0] * self.classifier.nb_classes
-                mask[fk_hat] = 1
+                # Choose coordinate and compute perturbation
                 norm = np.linalg.norm(grad_diff.reshape(self.classifier.nb_classes, -1), axis=1) + tol
-                value = np.ma.array(np.abs(f_diff) / norm, mask=mask)
-
-                l = value.argmin(fill_value=np.inf)
+                value = np.abs(f_diff) / norm
+                value[fk_hat] = np.inf
+                l = np.argmin(value)
                 r = (abs(f_diff[l]) / (pow(np.linalg.norm(grad_diff[l]), 2) + tol)) * grad_diff[l]
 
                 # Add perturbation and clip result
@@ -99,11 +94,14 @@ def generate(self, x, **kwargs):
                 grd = self.classifier.class_gradient(xj, logits=True)[0]
                 fk_i_hat = np.argmax(f)
 
-                nb_iter += 1
+                # Stop if misclassification has been achieved
+                if fk_i_hat != fk_hat:
+                    break
 
+            # Apply overshoot parameter
             x_adv[j] = np.clip(x[j] + (1 + self.epsilon) * (xj[0] - x[j]), clip_min, clip_max)
 
-        preds = np.argmax(self.classifier.predict(x), axis=1)
+        preds = np.argmax(preds, axis=1)
         preds_adv = np.argmax(self.classifier.predict(x_adv), axis=1)
         logger.info('Success rate of DeepFool attack: %.2f%%', (np.sum(preds != preds_adv) / x.shape[0]))
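The refactor above replaces the masked-array bookkeeping with a plain argmin over per-class linearizations. For intuition, here is a hedged, toy one-step version of that computation on a small linear classifier; W, b, and x are made-up values, and only the value/l/r expressions mirror the diff.

# One DeepFool-style step on a toy linear classifier f(x) = W @ x + b (illustrative only).
import numpy as np

tol = 10e-8
W = np.array([[1.0, 2.0], [0.5, -1.0], [-2.0, 0.3]])  # 3 classes, 2 features (made up)
b = np.array([0.1, -0.2, 0.05])
x = np.array([0.7, -0.4])

f = W @ x + b                  # logits
grd = W                        # class gradients; row k is d f_k / d x
fk_hat = int(np.argmax(f))     # currently predicted class

grad_diff = grd - grd[fk_hat]  # gradients of the logit gaps
f_diff = f - f[fk_hat]         # the logit gaps themselves

# Pick the class whose decision boundary is closest under the linearization.
norm = np.linalg.norm(grad_diff.reshape(len(f), -1), axis=1) + tol
value = np.abs(f_diff) / norm
value[fk_hat] = np.inf         # never pick the current class
l = int(np.argmin(value))

# Minimal perturbation that crosses that boundary (before the overshoot factor).
r = (abs(f_diff[l]) / (np.linalg.norm(grad_diff[l]) ** 2 + tol)) * grad_diff[l]
x_adv = x + r

In the attack itself this step repeats until the predicted class changes or max_iter is reached, and the final perturbation is scaled by (1 + epsilon) before clipping.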
