Skip to content

Commit ee0264d

Browse files
MARIA NICOLAEGitHub Enterprise
authored andcommitted
Merge pull request #53 from M-N-Tran/dev
Implement the spatial smoothing defence
2 parents 74d04f8 + 1e6a660 commit ee0264d

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

src/defences/spatial_smoothing.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
from scipy import ndimage
4+
5+
from src.defences.preprocessor import Preprocessor
6+
7+
8+
class SpatialSmoothing(Preprocessor):
9+
"""
10+
Implement the local spatial smoothing defence approach.
11+
Defence method from https://arxiv.org/abs/1704.01155.
12+
"""
13+
params = ["window_size"]
14+
15+
def __init__(self, window_size=3):
16+
"""
17+
Create an instance of local spatial smoothing.
18+
:param window_size: (int) The size of the sliding window.
19+
"""
20+
self.is_fitted = True
21+
self.set_params(window_size=window_size)
22+
23+
def __call__(self, x_val, window_size=3):
24+
"""
25+
Apply local spatial smoothing to sample x_val.
26+
:param x_val: (np.ndarray) Sample to smooth. `x_val` is supposed to
27+
have shape (batch_size, width, height, depth).
28+
:param window_size: (int) The size of the sliding window.
29+
:return: Smoothed sample
30+
:rtype: np.ndarray
31+
"""
32+
self.set_params(window_size=window_size)
33+
size = (1, window_size, window_size, 1)
34+
result = ndimage.filters.median_filter(x_val, size=size, mode="reflect")
35+
36+
return result
37+
38+
def fit(self, x_val, y_val=None, **kwargs):
39+
"""
40+
No parameters to learn for this method; do nothing.
41+
"""
42+
pass
43+
44+
def set_params(self, **kwargs):
45+
"""
46+
Take in a dictionary of parameters and applies defense-specific checks
47+
before saving them as attributes.
48+
Defense-specific parameters:
49+
:param window_size: (int) The size of the sliding window.
50+
"""
51+
# Save attack-specific parameters
52+
super(SpatialSmoothing, self).set_params(**kwargs)
53+
54+
if type(self.window_size) is not int or self.window_size <= 0:
55+
raise ValueError("Sliding window size must be a positive integer")
56+
57+
return True
58+
59+
60+
61+
62+
63+
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import unittest
4+
5+
import numpy as np
6+
7+
from src.defences.spatial_smoothing import SpatialSmoothing
8+
9+
10+
class TestLocalSpatialSmoothing(unittest.TestCase):
11+
def test_ones(self):
12+
m, n = 10, 2
13+
x = np.ones((1, m, n, 3))
14+
15+
# Start to test
16+
for window_size in range(1, 20):
17+
with self.subTest("Sliding window size = {}".format(window_size)):
18+
preprocess = SpatialSmoothing()
19+
smoothed_x = preprocess(x, window_size)
20+
self.assertTrue((smoothed_x == 1).all())
21+
22+
def test_fix(self):
23+
x = np.array([[[[1], [2], [3]], [[7], [8], [9]], [[4], [5], [6]]]])
24+
25+
# Start to test
26+
preprocess = SpatialSmoothing()
27+
smoothed_x = preprocess(x, 3)
28+
self.assertTrue((smoothed_x==np.array(
29+
[[[[2], [3], [3]], [[4], [5], [6]], [[5], [6], [6]]]])).all())
30+
31+
smoothed_x = preprocess(x, 1)
32+
self.assertTrue((smoothed_x==x).all())
33+
34+
smoothed_x = preprocess(x, 2)
35+
self.assertTrue((smoothed_x==np.array(
36+
[[[[1], [2], [3]], [[7], [7], [8]], [[7], [7], [8]]]])).all())
37+
38+
39+
if __name__ == '__main__':
40+
unittest.main()
41+
42+
43+
44+

0 commit comments

Comments
 (0)