Skip to content

Commit e8cd22a

Browse files
Irina Nicolae
authored and committed
Add Gaussian augmentation in preprocessing
1 parent 12d77c6 commit e8cd22a

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
from art.defences.preprocessor import Preprocessor
4+
5+
6+
class GaussianAugmentation(Preprocessor):
    """
    Perform Gaussian augmentation on a dataset.

    Randomly selected samples (drawn with replacement) are perturbed with Gaussian noise and appended to the
    original data, so the output always contains the untouched input followed by the noisy copies.
    """
    params = ['sigma', 'ratio']

    def __init__(self, sigma=1., ratio=1.):
        """
        Initialize a Gaussian augmentation object.

        :param sigma: Standard deviation of Gaussian noise to be added.
        :type sigma: `float`
        :param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
        :type ratio: `float`
        :raises ValueError: If `ratio` is not strictly positive.
        """
        super(GaussianAugmentation, self).__init__()
        # There is nothing to learn, so the preprocessor is usable right away.
        self._is_fitted = True
        self.set_params(sigma=sigma, ratio=ratio)

    def __call__(self, x, y=None, sigma=None, ratio=None):
        """
        Augment the sample `(x, y)` with Gaussian noise. The result is an extended dataset containing the original
        sample, as well as the newly created noisy samples.

        Any `sigma` or `ratio` given here is forwarded to `set_params` before augmenting.
        NOTE(review): this presumably also updates the values stored on the instance for later calls -- confirm
        against `Preprocessor.set_params`.

        :param x: Sample to augment with shape `(batch_size, width, height, depth)`. The first axis is treated as
               the sample axis; the remaining axes are preserved.
        :type x: `np.ndarray`
        :param y: Labels for the sample. If this argument is provided, it will be augmented with the corresponding
               original labels of each sample point.
        :type y: `np.ndarray`
        :param sigma: Standard deviation of Gaussian noise to be added.
        :type sigma: `float`
        :param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
        :type ratio: `float`
        :return: The augmented dataset and (if provided) corresponding labels.
        :rtype: `np.ndarray` or `(np.ndarray, np.ndarray)`
        :raises ValueError: If the resulting `ratio` is not strictly positive.
        """
        import numpy as np

        # Only forward the parameters that were explicitly overridden for this call
        params = {}
        if sigma is not None:
            params['sigma'] = sigma

        if ratio is not None:
            params['ratio'] = ratio

        if params:
            self.set_params(**params)

        # Select indices to augment (sampling with replacement, so `ratio` may exceed 1)
        size = int(x.shape[0] * self.ratio)
        indices = np.random.randint(0, x.shape[0], size=size)

        # Generate noisy samples centred on the selected originals
        x_selected = x[indices]
        x_noisy = np.random.normal(x_selected, scale=self.sigma, size=x_selected.shape)

        # Concatenate along the sample axis. Unlike `np.vstack`, this also extends 1-D inputs along axis 0
        # instead of stacking them as rows, and matches how the labels are extended below.
        x_aug = np.concatenate((x, x_noisy), axis=0)

        if y is None:
            return x_aug

        # Noisy samples keep the label of the sample they were generated from
        y_aug = np.concatenate((y, y[indices]))
        return x_aug, y_aug

    def fit(self, x, y=None, **kwargs):
        """
        No parameters to learn for this method; do nothing.
        """
        pass

    def set_params(self, **kwargs):
        """
        Take in a dictionary of parameters and applies defense-specific checks before saving them as attributes.

        :param sigma: Standard deviation of Gaussian noise to be added.
        :type sigma: `float`
        :param ratio: Percentage of data augmentation. E.g. for a ratio of 1, the size of the dataset will double.
        :type ratio: `float`
        :return: `True` once the parameters have passed validation.
        :rtype: `bool`
        :raises ValueError: If `ratio` is not strictly positive.
        """
        # Save defence-specific parameters
        super(GaussianAugmentation, self).set_params(**kwargs)

        if self.ratio <= 0:
            raise ValueError("The augmentation ratio must be positive.")

        return True
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import unittest
4+
5+
import numpy as np
6+
7+
from art.defences.gaussian_augmentation import GaussianAugmentation
8+
9+
10+
class TestGaussianAugmentation(unittest.TestCase):
    """
    Shape checks for `GaussianAugmentation`; only output sizes are asserted, so the random noise itself does not
    need to be seeded.
    """

    def test_small_size(self):
        """A ratio below 1 appends `int(n * ratio)` noisy samples (5 * 0.4 -> 2)."""
        x = np.arange(15).reshape((5, 3))
        ga = GaussianAugmentation()
        new_x = ga(x, ratio=.4)
        self.assertEqual(new_x.shape, (7, 3))

    def test_double_size(self):
        """The default ratio of 1 doubles the number of samples."""
        x = np.arange(12).reshape((4, 3))
        ga = GaussianAugmentation()
        new_x = ga(x)
        self.assertEqual(new_x.shape[0], 2 * x.shape[0])

    def test_multiple_size(self):
        """A ratio above 1 appends more noisy samples than there are originals."""
        x = np.arange(12).reshape((4, 3))
        ga = GaussianAugmentation(ratio=3.5)
        new_x = ga(x)
        self.assertEqual(new_x.shape[0], int(4.5 * x.shape[0]))

    def test_labels(self):
        """Labels are extended alongside the data and keep their per-sample shape."""
        x = np.arange(12).reshape((4, 3))
        y = np.arange(8).reshape((4, 2))

        ga = GaussianAugmentation()
        new_x, new_y = ga(x, y)
        self.assertEqual(new_x.shape[0], 8)
        self.assertEqual(new_y.shape[0], 8)
        self.assertEqual(new_x.shape[1:], x.shape[1:])
        self.assertEqual(new_y.shape[1:], y.shape[1:])

0 commit comments

Comments
 (0)