Skip to content

Commit 1d566db

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Correct parameter setting in defenses
1 parent ee0264d commit 1d566db

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

src/defences/feature_squeezing.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import absolute_import, division, print_function, unicode_literals
22

33
import numpy as np
4-
from tensorflow import rint
4+
import tensorflow as tf
55

66
from src.defences.preprocessor import Preprocessor
77

@@ -21,7 +21,7 @@ def __init__(self, bit_depth=8):
2121
self.is_fitted = True
2222
self.set_params(bit_depth=bit_depth)
2323

24-
def __call__(self, x_val, bit_depth=8):
24+
def __call__(self, x_val, bit_depth=None):
2525
"""
2626
Apply feature squeezing to sample x_val.
2727
@@ -30,16 +30,17 @@ def __call__(self, x_val, bit_depth=8):
3030
:return: Squeezed sample
3131
:rtype: np.ndarray
3232
"""
33-
self.set_params(bit_depth=bit_depth)
33+
if bit_depth is not None:
34+
self.set_params(bit_depth=bit_depth)
3435

35-
max_value = int(2 ** bit_depth - 1)
36+
max_value = np.rint(2 ** self.bit_depth - 1)
3637
return np.rint(x_val * max_value) / max_value
3738

3839
def fit(self, x_val, y_val=None, **kwargs):
3940
"""No parameters to learn for this method; do nothing."""
4041
pass
4142

42-
def _tf_predict(self, x, bit_depth=8):
43+
def _tf_predict(self, x, bit_depth=None):
4344
"""
4445
Apply feature squeezing on tf.Tensor.
4546
@@ -48,10 +49,11 @@ def _tf_predict(self, x, bit_depth=8):
4849
:return: Squeezed sample
4950
:rtype: tf.Tensor
5051
"""
51-
self.set_params(bit_depth=bit_depth)
52+
if bit_depth is not None:
53+
self.set_params(bit_depth=bit_depth)
5254

53-
max_value = int(2 ** bit_depth - 1)
54-
x = rint(x * max_value) / max_value
55+
max_value = int(2 ** self.bit_depth - 1)
56+
x = tf.rint(x * max_value) / max_value
5557
return x
5658

5759
def set_params(self, **kwargs):

src/defences/spatial_smoothing.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, window_size=3):
2020
self.is_fitted = True
2121
self.set_params(window_size=window_size)
2222

23-
def __call__(self, x_val, window_size=3):
23+
def __call__(self, x_val, window_size=None):
2424
"""
2525
Apply local spatial smoothing to sample x_val.
2626
:param x_val: (np.ndarray) Sample to smooth. `x_val` is supposed to
@@ -29,8 +29,10 @@ def __call__(self, x_val, window_size=3):
2929
:return: Smoothed sample
3030
:rtype: np.ndarray
3131
"""
32-
self.set_params(window_size=window_size)
33-
size = (1, window_size, window_size, 1)
32+
if window_size is not None:
33+
self.set_params(window_size=window_size)
34+
35+
size = (1, self.window_size, self.window_size, 1)
3436
result = ndimage.filters.median_filter(x_val, size=size, mode="reflect")
3537

3638
return result
@@ -55,9 +57,3 @@ def set_params(self, **kwargs):
5557
raise ValueError("Sliding window size must be a positive integer")
5658

5759
return True
58-
59-
60-
61-
62-
63-

src/defences/spatial_smoothing_unittest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def test_fix(self):
2525
# Start to test
2626
preprocess = SpatialSmoothing()
2727
smoothed_x = preprocess(x, 3)
28-
self.assertTrue((smoothed_x==np.array(
28+
self.assertTrue((smoothed_x == np.array(
2929
[[[[2], [3], [3]], [[4], [5], [6]], [[5], [6], [6]]]])).all())
3030

3131
smoothed_x = preprocess(x, 1)
32-
self.assertTrue((smoothed_x==x).all())
32+
self.assertTrue((smoothed_x == x).all())
3333

3434
smoothed_x = preprocess(x, 2)
35-
self.assertTrue((smoothed_x==np.array(
35+
self.assertTrue((smoothed_x == np.array(
3636
[[[[1], [2], [3]], [[7], [7], [8]], [[7], [7], [8]]]])).all())
3737

3838

0 commit comments

Comments
 (0)