Commit a6717f6

Merge pull request #752 from davidslater/simplified-preprocessing
Simplified Preprocessor and Postprocessor Usage and Defaults
2 parents 01f05a8 + 2c49a75 commit a6717f6

24 files changed: +67 −527 lines

art/defences/postprocessor/class_labels.py

Lines changed: 1 addition & 18 deletions
@@ -39,18 +39,7 @@ def __init__(self, apply_fit: bool = False, apply_predict: bool = True) -> None:
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
-
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
 
     def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
@@ -67,9 +56,3 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
         class_labels[preds > 0.5] = 1
 
         return class_labels
-
-    def fit(self, preds: np.ndarray, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
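
The net effect of this hunk is that the class no longer redefines the flag properties or a no-op fit; it simply forwards everything to the base class. A minimal usage sketch, assuming the class name ClassLabels and the import path implied by the file path above (neither is stated in the diff itself):

    import numpy as np

    # Assumed import path, inferred from art/defences/postprocessor/class_labels.py.
    from art.defences.postprocessor.class_labels import ClassLabels

    # Defaults from the signature in the hunk header: apply_fit=False, apply_predict=True.
    postprocessor = ClassLabels()

    # These values now come from Postprocessor.__init__ rather than per-class properties.
    print(postprocessor.is_fitted, postprocessor.apply_fit, postprocessor.apply_predict)  # True False True

    # For binary scores, the branch shown above maps everything greater than 0.5 to label 1.
    preds = np.array([[0.2], [0.7], [0.51]])
    print(postprocessor(preds))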

art/defences/postprocessor/gaussian_noise.py

Lines changed: 1 addition & 18 deletions
@@ -43,21 +43,10 @@ def __init__(self, scale: float = 0.2, apply_fit: bool = False, apply_predict: b
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.scale = scale
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         Perform model postprocessing and return postprocessed output.
@@ -87,12 +76,6 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
 
         return post_preds
 
-    def fit(self, preds: np.ndarray, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if self.scale <= 0:
             raise ValueError("Standard deviation must be positive.")

art/defences/postprocessor/high_confidence.py

Lines changed: 1 addition & 18 deletions
@@ -42,21 +42,10 @@ def __init__(self, cutoff: float = 0.25, apply_fit: bool = False, apply_predict:
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.cutoff = cutoff
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         Perform model postprocessing and return postprocessed output.
@@ -69,12 +58,6 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
 
         return post_preds
 
-    def fit(self, preds: np.ndarray, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if self.cutoff <= 0:
             raise ValueError("Minimal value must be positive.")

art/defences/postprocessor/postprocessor.py

Lines changed: 9 additions & 8 deletions
@@ -34,11 +34,15 @@ class Postprocessor(abc.ABC):
 
     params: List[str] = []
 
-    def __init__(self) -> None:
+    def __init__(self, is_fitted: bool = False, apply_fit: bool = True, apply_predict: bool = True) -> None:
         """
         Create a postprocessing object.
+
+        Optionally, set attributes.
         """
-        self._is_fitted = False
+        self._is_fitted = bool(is_fitted)
+        self._apply_fit = bool(apply_fit)
+        self._apply_predict = bool(apply_predict)
 
     @property
     def is_fitted(self) -> bool:
@@ -50,24 +54,22 @@ def is_fitted(self) -> bool:
         return self._is_fitted
 
     @property
-    @abc.abstractmethod
     def apply_fit(self) -> bool:
         """
         Property of the defence indicating if it should be applied at training time.
 
         :return: `True` if the defence should be applied when fitting a model, `False` otherwise.
         """
-        raise NotImplementedError
+        return self._apply_fit
 
     @property
-    @abc.abstractmethod
     def apply_predict(self) -> bool:
         """
         Property of the defence indicating if it should be applied at test time.
 
         :return: `True` if the defence should be applied at prediction time, `False` otherwise.
         """
-        raise NotImplementedError
+        return self._apply_predict
 
     @abc.abstractmethod
     def __call__(self, preds: np.ndarray) -> np.ndarray:
@@ -79,15 +81,14 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         raise NotImplementedError
 
-    @abc.abstractmethod
     def fit(self, preds: np.ndarray, **kwargs) -> None:
         """
         Fit the parameters of the postprocessor if it has any.
 
         :param preds: Training set to fit the postprocessor.
         :param kwargs: Other parameters.
         """
-        raise NotImplementedError
+        pass
 
     def set_params(self, **kwargs) -> None:
         """

art/defences/postprocessor/reverse_sigmoid.py

Lines changed: 1 addition & 18 deletions
@@ -47,22 +47,11 @@ def __init__(
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.beta = beta
         self.gamma = gamma
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         Perform model postprocessing and return postprocessed output.
@@ -109,12 +98,6 @@ def sigmoid(var_z):
 
         return reverse_sigmoid
 
-    def fit(self, preds: np.ndarray, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if self.beta <= 0:
             raise ValueError("Magnitude parameter must be positive.")

art/defences/postprocessor/rounded.py

Lines changed: 1 addition & 18 deletions
@@ -42,21 +42,10 @@ def __init__(self, decimals: int = 3, apply_fit: bool = False, apply_predict: bo
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.decimals = decimals
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         Perform model postprocessing and return postprocessed output.
@@ -66,12 +55,6 @@ def __call__(self, preds: np.ndarray) -> np.ndarray:
         """
         return np.around(preds, decimals=self.decimals)
 
-    def fit(self, preds: np.ndarray, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if not isinstance(self.decimals, (int, np.int)) or self.decimals <= 0:
             raise ValueError("Number of decimal places must be a positive integer.")

art/defences/preprocessor/feature_squeezing.py

Lines changed: 1 addition & 21 deletions
@@ -62,22 +62,11 @@ def __init__(
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.clip_values = clip_values
         self.bit_depth = bit_depth
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
         """
         Apply feature squeezing to sample `x`.
@@ -97,15 +86,6 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
 
         return res, y
 
-    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
-        return grad
-
-    def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if not isinstance(self.bit_depth, (int, np.int)) or self.bit_depth <= 0 or self.bit_depth > 64:
             raise ValueError("The bit depth must be between 1 and 64.")

art/defences/preprocessor/gaussian_augmentation.py

Lines changed: 1 addition & 21 deletions
@@ -73,31 +73,20 @@ def __init__(
         :param apply_fit: True if applied during fitting/training.
         :param apply_predict: True if applied during predicting.
         """
-        super().__init__()
-        self._is_fitted = True
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         if augmentation and not apply_fit and apply_predict:
             raise ValueError(
                 "If `augmentation` is `True`, then `apply_fit` must be `True` and `apply_predict` must be `False`."
             )
         if augmentation and not (apply_fit or apply_predict):
             raise ValueError("If `augmentation` is `True`, then `apply_fit` and `apply_predict` can't be both `False`.")
 
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
         self.sigma = sigma
         self.augmentation = augmentation
         self.ratio = ratio
         self.clip_values = clip_values
         self._check_params()
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
         """
         Augment the sample `(x, y)` with Gaussian noise. The result is either an extended dataset containing the
@@ -134,15 +123,6 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
 
         return x_aug, y_aug
 
-    def estimate_gradient(self, x: np.ndarray, grad: np.ndarray) -> np.ndarray:
-        return grad
-
-    def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> None:
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if self.augmentation and self.ratio <= 0:
             raise ValueError("The augmentation ratio must be positive.")

art/defences/preprocessor/inverse_gan.py

Lines changed: 1 addition & 19 deletions
@@ -68,11 +68,7 @@ def __init__(
         """
         import tensorflow as tf  # lgtm [py/repeated-import]
 
-        super().__init__()
-
-        self._is_fitted = True
-        self._apply_fit = apply_fit
-        self._apply_predict = apply_predict
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
         self.gan = gan
         self.inverse_gan = inverse_gan
         self.sess = sess
@@ -164,14 +160,6 @@ def loss(self, z_encoding: np.ndarray, image_adv: np.ndarray) -> np.ndarray:
         loss = self.sess.run(self._loss, feed_dict={self.gan.input_ph: z_encoding, self._image_adv: image_adv})
         return loss
 
-    @property
-    def apply_fit(self) -> bool:
-        return self._apply_fit
-
-    @property
-    def apply_predict(self) -> bool:
-        return self._apply_predict
-
     def estimate_gradient(self, z_encoding: np.ndarray, y: np.ndarray) -> np.ndarray:
         """
         Compute the gradient of the loss function w.r.t. a `z_encoding` input within a GAN against a
@@ -186,12 +174,6 @@ def estimate_gradient(self, z_encoding: np.ndarray, y: np.ndarray) -> np.ndarray
         gradient = self.sess.run(self._grad, feed_dict={self._image_adv: y, self.gan.input_ph: z_encoding})
         return gradient
 
-    def fit(self, x, y=None, **kwargs):
-        """
-        No parameters to learn for this method; do nothing.
-        """
-        pass
-
     def _check_params(self) -> None:
         if self.inverse_gan is not None and self.gan.encoding_length != self.inverse_gan.encoding_length:
             raise ValueError("Both GAN and InverseGAN must use the same size encoding.")
