Skip to content

Commit a9c7531

Browse files
Irina Nicolae
authored and committed
Merge Clever fix, Gaussian aug and Keras pred batching
2 parents dcb2266 + c3a7ac5 commit a9c7531

File tree

8 files changed

+156
-18
lines changed

8 files changed

+156
-18
lines changed

art/classifiers/keras.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,16 @@ def predict(self, inputs, logits=False):
114114
# Apply defences
115115
inputs = self._apply_defences_predict(inputs)
116116

117-
preds = self._preds([inputs])[0]
118-
if not logits:
119-
exp = np.exp(preds - np.max(preds, axis=1, keepdims=True))
120-
preds = exp / np.sum(exp, axis=1, keepdims=True)
117+
# Run predictions with batching
118+
batch_size = 512
119+
preds = np.zeros((inputs.shape[0], self.nb_classes), dtype=np.float32)
120+
for b in range(inputs.shape[0] // batch_size + 1):
121+
begin, end = b * batch_size, min((b + 1) * batch_size, inputs.shape[0])
122+
preds[begin:end] = self._preds([inputs[begin:end]])[0]
123+
124+
if not logits:
125+
exp = np.exp(preds[begin:end] - np.max(preds[begin:end], axis=1, keepdims=True))
126+
preds[begin:end] = exp / np.sum(exp, axis=1, keepdims=True)
121127

122128
return preds
123129

@@ -141,13 +147,13 @@ def fit(self, inputs, outputs, batch_size=128, nb_epochs=20):
141147
# Apply defences
142148
inputs, outputs = self._apply_defences_fit(inputs, outputs)
143149

144-
gen = generator(inputs, outputs, batch_size)
150+
gen = generator_fit(inputs, outputs, batch_size)
145151
self._model.fit_generator(gen, steps_per_epoch=inputs.shape[0] / batch_size, epochs=nb_epochs)
146152

147153

148-
def generator(data, labels, batch_size=128):
154+
def generator_fit(data, labels, batch_size=128):
149155
"""
150-
Minimal data generator for batching large datasets.
156+
Minimal data generator for randomly batching large datasets.
151157
152158
:param data: The data sample to batch.
153159
:type data: `np.ndarray`
@@ -160,4 +166,4 @@ def generator(data, labels, batch_size=128):
160166
"""
161167
while True:
162168
indices = np.random.randint(data.shape[0], size=batch_size)
163-
yield data[indices], labels[indices]
169+
yield data[indices], labels[indices]

art/defences/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
"""
44
from art.defences.adversarial_trainer import AdversarialTrainer
55
from art.defences.feature_squeezing import FeatureSqueezing
6+
from art.defences.gaussian_augmentation import GaussianAugmentation
67
from art.defences.label_smoothing import LabelSmoothing
78
from art.defences.spatial_smoothing import SpatialSmoothing
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
from art.defences.preprocessor import Preprocessor
4+
5+
6+
class GaussianAugmentation(Preprocessor):
7+
"""
8+
Perform Gaussian augmentation on a dataset.
9+
"""
10+
params = ['sigma', 'ratio']
11+
12+
def __init__(self, sigma=1., ratio=1.):
13+
"""
14+
Initialize a Gaussian augmentation object.
15+
16+
:param sigma: Standard deviation of Gaussian noise to be added.
17+
:type sigma: `float`
18+
:param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
19+
:type ratio: `float`
20+
"""
21+
super(GaussianAugmentation, self).__init__()
22+
self._is_fitted = True
23+
self.set_params(sigma=sigma, ratio=ratio)
24+
25+
def __call__(self, x, y=None, sigma=None, ratio=None):
26+
"""
27+
Augment the sample `(x, y)` with Gaussian noise. The result is an extended dataset containing the original
28+
sample, as well as the newly created noisy samples.
29+
30+
:param x: Sample to augment with shape `(batch_size, width, height, depth)`.
31+
:type x: `np.ndarray`
32+
:param y: Labels for the sample. If this argument is provided, it will be augmented with the corresponding
33+
original labels of each sample point.
34+
:param sigma: Standard deviation of Gaussian noise to be added.
35+
:type sigma: `float`
36+
:param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
37+
:type ratio: `float`
38+
:return: The augmented dataset and (if provided) corresponding labels.
39+
:rtype: `np.ndarray` or `(np.ndarray, np.ndarray)`
40+
"""
41+
# Set params
42+
params = {}
43+
if sigma is not None:
44+
params['sigma'] = sigma
45+
46+
if ratio is not None:
47+
params['ratio'] = ratio
48+
49+
if params:
50+
self.set_params(**params)
51+
52+
# Select indices to augment
53+
import numpy as np
54+
size = int(x.shape[0] * self.ratio)
55+
indices = np.random.randint(0, x.shape[0], size=size)
56+
57+
# Generate noisy samples
58+
x_aug = np.random.normal(x[indices], scale=self.sigma, size=(size,) + x[indices].shape[1:])
59+
x_aug = np.vstack((x, x_aug))
60+
61+
if y is not None:
62+
y_aug = np.concatenate((y, y[indices]))
63+
return x_aug, y_aug
64+
else:
65+
return x_aug
66+
67+
def fit(self, x, y=None, **kwargs):
68+
"""
69+
No parameters to learn for this method; do nothing.
70+
"""
71+
pass
72+
73+
def set_params(self, **kwargs):
74+
"""
75+
Take in a dictionary of parameters and apply defense-specific checks before saving them as attributes.
76+
77+
:param sigma: Standard deviation of Gaussian noise to be added.
78+
:type sigma: `float`
79+
:param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
80+
:type ratio: `float`
81+
"""
82+
# Save defense-specific parameters
83+
super(GaussianAugmentation, self).set_params(**kwargs)
84+
85+
if self.ratio <= 0:
86+
raise ValueError("The augmentation ratio must be positive.")
87+
88+
return True
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import unittest
4+
5+
import numpy as np
6+
7+
from art.defences.gaussian_augmentation import GaussianAugmentation
8+
9+
10+
class TestGaussianAugmentation(unittest.TestCase):
11+
def test_small_size(self):
12+
x = np.arange(15).reshape((5, 3))
13+
ga = GaussianAugmentation()
14+
new_x = ga(x, ratio=.4)
15+
self.assertTrue(new_x.shape == (7, 3))
16+
17+
def test_double_size(self):
18+
x = np.arange(12).reshape((4, 3))
19+
ga = GaussianAugmentation()
20+
new_x = ga(x)
21+
self.assertTrue(new_x.shape[0] == 2 * x.shape[0])
22+
23+
def test_multiple_size(self):
24+
x = np.arange(12).reshape((4, 3))
25+
ga = GaussianAugmentation(ratio=3.5)
26+
new_x = ga(x)
27+
self.assertTrue(int(4.5 * x.shape[0]) == new_x.shape[0])
28+
29+
def test_labels(self):
30+
x = np.arange(12).reshape((4, 3))
31+
y = np.arange(8).reshape((4, 2))
32+
33+
ga = GaussianAugmentation()
34+
new_x, new_y = ga(x, y)
35+
self.assertTrue(new_x.shape[0] == new_y.shape[0] == 8)
36+
self.assertTrue(new_x.shape[1:] == x.shape[1:])
37+
self.assertTrue(new_y.shape[1:] == y.shape[1:])

art/defences/spatial_smoothing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, window_size=3):
2424

2525
def __call__(self, x, window_size=None):
2626
"""
27-
Apply local spatial smoothing to sample `x_val`.
27+
Apply local spatial smoothing to sample `x`.
2828
2929
:param x: Sample to smooth with shape `(batch_size, width, height, depth)`.
3030
:type x: `np.ndarray`

art/metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,6 @@ def clever_t(classifier, x, target_class, n_b, n_s, r, norm, c_init=1, pool_fact
201201
if pool_factor < 1:
202202
raise ValueError("The pool_factor must be larger than 1")
203203

204-
# Change norm since q = p / (p-1)
205-
if norm == 1:
206-
norm = np.inf
207-
elif norm == np.inf:
208-
norm = 1
209-
elif norm != 2:
210-
raise ValueError("Norm {} not supported".format(norm))
211-
212204
# Some auxiliary vars
213205
grad_norm_set = []
214206
dim = reduce(lambda x_, y: x_ * y, x.shape, 1)
@@ -220,6 +212,14 @@ def clever_t(classifier, x, target_class, n_b, n_s, r, norm, c_init=1, pool_fact
220212
rand_pool += np.repeat(np.array([x]), pool_factor * n_s, 0)
221213
np.clip(rand_pool, classifier.clip_values[0], classifier.clip_values[1], out=rand_pool)
222214

215+
# Change norm since q = p / (p-1)
216+
if norm == 1:
217+
norm = np.inf
218+
elif norm == np.inf:
219+
norm = 1
220+
elif norm != 2:
221+
raise ValueError("Norm {} not supported".format(norm))
222+
223223
# Loop over n_b batches
224224
for i in range(n_b):
225225
# Random generation of data points

docs/modules/defences.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ Adversarial Training
2525
.. autoclass:: AdversarialTrainer
2626
:members:
2727
:special-members:
28+
29+
Gaussian Data Augmentation
30+
--------------------------
31+
.. autoclass:: GaussianAugmentation
32+
:members:
33+
:special-members:

docs/modules/detection.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
:mod:`art.detection`
2-
===================
2+
====================
33

44
Base Class
55
----------

0 commit comments

Comments
 (0)