Skip to content

Commit d1f97b7

Browse files
author
Beat Buesser
committed
Updates for BaseEstimator.set_params
Signed-off-by: Beat Buesser <[email protected]>
1 parent 96957a1 commit d1f97b7

File tree

4 files changed

+92
-66
lines changed

4 files changed

+92
-66
lines changed

art/estimators/estimator.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -72,55 +72,64 @@ def __init__(
7272
used for data preprocessing. The first value will be subtracted from the input and the results will be
7373
divided by the second value.
7474
"""
75-
from art.defences.postprocessor.postprocessor import Postprocessor
7675
from art.defences.preprocessor.preprocessor import Preprocessor
7776

7877
self._model = model
7978
self._clip_values = clip_values
8079

81-
self.preprocessing: List["Preprocessor"] = []
80+
self.preprocessing = preprocessing
81+
self.preprocessing_defences = self._set_preprocessing_defences(preprocessing_defences)
82+
self.postprocessing_defences = self._set_postprocessing_defences(postprocessing_defences)
83+
self.preprocessing_operations: List["Preprocessor"] = []
84+
self._update_preprocessing_operations()
85+
self._check_params()
86+
87+
def _update_preprocessing_operations(self):
88+
from art.defences.preprocessor.preprocessor import Preprocessor
89+
90+
self.preprocessing_operations.clear()
91+
92+
if self.preprocessing_defences is None:
93+
pass
94+
elif isinstance(self.preprocessing_defences, Preprocessor):
95+
self.preprocessing_operations.append(self.preprocessing_defences)
96+
else:
97+
self.preprocessing_operations += self.preprocessing_defences
8298

83-
# preprocessing
84-
self._preprocessing_argument = None
85-
if preprocessing is None:
99+
if self.preprocessing is None:
86100
pass
87-
elif isinstance(preprocessing, tuple):
101+
elif isinstance(self.preprocessing, tuple):
88102
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
89103

90-
self._preprocessing_argument = StandardisationMeanStd(mean=preprocessing[0], std=preprocessing[1])
91-
elif isinstance(preprocessing, Preprocessor):
92-
self._preprocessing_argument = preprocessing
104+
self.preprocessing_operations.append(
105+
StandardisationMeanStd(mean=self.preprocessing[0], std=self.preprocessing[1])
106+
)
107+
elif isinstance(self.preprocessing, Preprocessor):
108+
self.preprocessing_operations.append(self.preprocessing)
93109
else:
94110
raise ValueError("Preprocessing argument not recognised.")
95111

96-
# preprocessing_defences
97-
self._set_preprocessing_defences(preprocessing_defences)
98-
99-
# postprocessing_defences
100-
self.postprocessing_defences: Optional[List["Postprocessor"]]
101-
self._set_postprocessing_defences(postprocessing_defences)
102-
103-
self._check_params()
104-
105-
def _set_preprocessing_defences(self, preprocessing_defences):
112+
@staticmethod
113+
def _set_preprocessing_defences(
114+
preprocessing_defences: Optional[Union["Preprocessor", List["Preprocessor"]]]
115+
) -> Optional[List["Preprocessor"]]:
106116
from art.defences.preprocessor.preprocessor import Preprocessor
107117

108-
if preprocessing_defences is None:
109-
self.preprocessing = [self._preprocessing_argument]
110-
elif isinstance(preprocessing_defences, Preprocessor):
111-
self.preprocessing = [preprocessing_defences] + [self._preprocessing_argument]
118+
if isinstance(preprocessing_defences, Preprocessor):
119+
return [preprocessing_defences]
112120
else:
113-
self.preprocessing = preprocessing_defences + [self._preprocessing_argument]
114-
115-
self.preprocessing_defences = preprocessing_defences
121+
return preprocessing_defences
116122

117-
def _set_postprocessing_defences(self, postprocessing_defences):
123+
@staticmethod
124+
def _set_postprocessing_defences(
125+
postprocessing_defences: Optional[Union["Postprocessor", List["Postprocessor"]]]
126+
) -> Optional[List["Postprocessor"]]:
118127
from art.defences.postprocessor.postprocessor import Postprocessor
119128

120129
if isinstance(postprocessing_defences, Postprocessor):
121-
self.postprocessing_defences = [postprocessing_defences]
130+
return [postprocessing_defences]
122131
else:
123-
self.postprocessing_defences = postprocessing_defences
132+
return postprocessing_defences
124133

125134
def set_params(self, **kwargs) -> None:
126135
"""
@@ -141,6 +150,7 @@ def set_params(self, **kwargs) -> None:
141150
setattr(self, key, value)
142151
else:
143152
raise ValueError("Unexpected parameter `{}` found in kwargs.".format(key))
153+
self._update_preprocessing_operations()
144154
self._check_params()
145155

146156
def get_params(self) -> Dict[str, Any]:
@@ -171,8 +181,8 @@ def _check_params(self) -> None:
171181
else:
172182
self._clip_values = np.array(self._clip_values, dtype=ART_NUMPY_DTYPE)
173183

174-
if isinstance(self.preprocessing, list):
175-
for preprocess in self.preprocessing:
184+
if isinstance(self.preprocessing_operations, list):
185+
for preprocess in self.preprocessing_operations:
176186
if not isinstance(preprocess, Preprocessor):
177187
raise ValueError(
178188
"All preprocessing defences have to be instance of "
@@ -265,7 +275,7 @@ def _apply_preprocessing(self, x, y, fit: bool) -> Tuple[Any, Any]:
265275
:rtype: Format as expected by the `model`
266276
"""
267277
if self.preprocessing:
268-
for preprocess in self.preprocessing:
278+
for preprocess in self.preprocessing_operations:
269279
if fit:
270280
if preprocess.apply_fit:
271281
x, y = preprocess(x, y)
@@ -341,8 +351,8 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
341351
:return: Gradients after backward pass through normalization and preprocessing defences.
342352
:rtype: Format as expected by the `model`
343353
"""
344-
if self.preprocessing:
345-
for preprocess in self.preprocessing[::-1]:
354+
if self.preprocessing_operations:
355+
for preprocess in self.preprocessing_operations[::-1]:
346356
if fit:
347357
if preprocess.apply_fit:
348358
gradients = preprocess.estimate_gradient(x, gradients)

art/estimators/pytorch.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def _check_params(self) -> None:
123123
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
124124

125125
super()._check_params()
126-
self.all_framework_preprocessing = all([isinstance(p, PreprocessorPyTorch) for p in self.preprocessing])
126+
self.all_framework_preprocessing = all(
127+
[isinstance(p, PreprocessorPyTorch) for p in self.preprocessing_operations]
128+
)
127129

128130
def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[Any, Any]:
129131
"""
@@ -153,7 +155,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[A
153155
StandardisationMeanStdPyTorch,
154156
)
155157

156-
if not self.preprocessing:
158+
if not self.preprocessing_operations:
157159
return x, y
158160

159161
if isinstance(x, torch.Tensor):
@@ -169,7 +171,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[A
169171
y = torch.tensor(y, device=self._device)
170172

171173
def chain_processes(x, y):
172-
for preprocess in self.preprocessing:
174+
for preprocess in self.preprocessing_operations:
173175
if fit:
174176
if preprocess.apply_fit:
175177
x, y = preprocess.forward(x, y)
@@ -190,12 +192,12 @@ def chain_processes(x, y):
190192
if y is not None:
191193
y = y.cpu().numpy()
192194

193-
elif len(self.preprocessing) == 1 or (
194-
len(self.preprocessing) == 2
195-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
195+
elif len(self.preprocessing_operations) == 1 or (
196+
len(self.preprocessing_operations) == 2
197+
and isinstance(self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
196198
):
197199
# Compatible with non-PyTorch defences if no chaining.
198-
for preprocess in self.preprocessing:
200+
for preprocess in self.preprocessing_operations:
199201
x, y = preprocess(x, y)
200202

201203
else:
@@ -229,7 +231,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
229231
StandardisationMeanStdPyTorch,
230232
)
231233

232-
if not self.preprocessing:
234+
if not self.preprocessing_operations:
233235
return gradients
234236

235237
if isinstance(x, torch.Tensor):
@@ -243,7 +245,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
243245
gradients = torch.tensor(gradients, device=self._device)
244246
x_orig = x
245247

246-
for preprocess in self.preprocessing:
248+
for preprocess in self.preprocessing_operations:
247249
if fit:
248250
if preprocess.apply_fit:
249251
x = preprocess.estimate_forward(x)
@@ -260,13 +262,18 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
260262
"The input shape is {} while the gradient shape is {}".format(x.shape, gradients.shape)
261263
)
262264

263-
elif len(self.preprocessing) == 1 or (
264-
len(self.preprocessing) == 2
265+
elif len(self.preprocessing_operations) == 1 or (
266+
len(self.preprocessing_operations) == 2
265267
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
266268
):
267269
# Compatible with non-PyTorch defences if no chaining.
268-
defence = self.preprocessing[0]
269-
gradients = defence.estimate_gradient(x, gradients)
270+
for preprocess in self.preprocessing_operations[::-1]:
271+
if fit:
272+
if preprocess.apply_fit:
273+
gradients = preprocess.estimate_gradient(x, gradients)
274+
else:
275+
if preprocess.apply_predict:
276+
gradients = preprocess.estimate_gradient(x, gradients)
270277

271278
else:
272279
raise NotImplementedError("The current combination of preprocessing types is not supported.")

art/estimators/tensorflow.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def _check_params(self) -> None:
169169
from art.defences.preprocessor.preprocessor import PreprocessorTensorFlowV2
170170

171171
super()._check_params()
172-
self.all_framework_preprocessing = all([isinstance(p, PreprocessorTensorFlowV2) for p in self.preprocessing])
172+
self.all_framework_preprocessing = all(
173+
[isinstance(p, PreprocessorTensorFlowV2) for p in self.preprocessing_operations]
174+
)
173175

174176
def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
175177
"""
@@ -197,7 +199,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
197199
StandardisationMeanStdTensorFlowV2,
198200
)
199201

200-
if not self.preprocessing:
202+
if not self.preprocessing_operations:
201203
return x, y
202204

203205
if isinstance(x, tf.Tensor):
@@ -212,7 +214,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
212214
if y is not None:
213215
y = tf.convert_to_tensor(y)
214216

215-
for preprocess in self.preprocessing:
217+
for preprocess in self.preprocessing_operations:
216218
if fit:
217219
if preprocess.apply_fit:
218220
x, y = preprocess.forward(x, y)
@@ -226,12 +228,14 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
226228
if y is not None:
227229
y = y.numpy()
228230

229-
elif len(self.preprocessing) == 1 or (
230-
len(self.preprocessing) == 2
231-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2))
231+
elif len(self.preprocessing_operations) == 1 or (
232+
len(self.preprocessing_operations) == 2
233+
and isinstance(
234+
self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2)
235+
)
232236
):
233237
# Compatible with non-TensorFlow defences if no chaining.
234-
for preprocess in self.preprocessing:
238+
for preprocess in self.preprocessing_operations:
235239
x, y = preprocess(x, y)
236240

237241
else:
@@ -265,7 +269,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
265269
StandardisationMeanStdTensorFlowV2,
266270
)
267271

268-
if not self.preprocessing:
272+
if not self.preprocessing_operations:
269273
return gradients
270274

271275
if isinstance(x, tf.Tensor):
@@ -281,7 +285,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
281285
gradients = tf.convert_to_tensor(gradients, dtype=config.ART_NUMPY_DTYPE)
282286
x_orig = x
283287

284-
for preprocess in self.preprocessing:
288+
for preprocess in self.preprocessing_operations:
285289
if fit:
286290
if preprocess.apply_fit:
287291
x = preprocess.estimate_forward(x)
@@ -298,13 +302,18 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
298302
"The input shape is {} while the gradient shape is {}".format(x.shape, gradients.shape)
299303
)
300304

301-
elif len(self.preprocessing) == 1 or (
302-
len(self.preprocessing) == 2
305+
elif len(self.preprocessing_operations) == 1 or (
306+
len(self.preprocessing_operations) == 2
303307
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2))
304308
):
305309
# Compatible with non-TensorFlow defences if no chaining.
306-
defence = self.preprocessing[0]
307-
gradients = defence.estimate_gradient(x, gradients)
310+
for preprocess in self.preprocessing_operations[::-1]:
311+
if fit:
312+
if preprocess.apply_fit:
313+
gradients = preprocess.estimate_gradient(x, gradients)
314+
else:
315+
if preprocess.apply_predict:
316+
gradients = preprocess.estimate_gradient(x, gradients)
308317

309318
else:
310319
raise NotImplementedError("The current combination of preprocessing types is not supported.")

tests/estimators/classification/test_deeplearning_common.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@
392392
"train_step=<function get_image_classifier_tf_v2.<locals>.train_step",
393393
"channels_first=False, ",
394394
"clip_values=array([0., 1.], dtype=float32), ",
395-
"preprocessing_defences=None, postprocessing_defences=None, preprocessing=[StandardisationMeanStdTensorFlowV2(mean=0, std=1, apply_fit=True, apply_predict=True)]"
395+
"preprocessing_defences=None, postprocessing_defences=None, preprocessing=StandardisationMeanStdTensorFlowV2(mean=0, std=1, apply_fit=True, apply_predict=True)"
396396
],
397397
"test_repr_tensorflow1": [
398398
"TensorFlowClassifier",
@@ -406,7 +406,7 @@
406406
"channels_first=False",
407407
"clip_values=array([0., 1.], dtype=float32)",
408408
"preprocessing_defences=None, postprocessing_defences=None",
409-
"preprocessing=[StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)]"
409+
"preprocessing=StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)"
410410
],
411411
"test_repr_pytorch": [
412412
"art.estimators.classification.pytorch.PyTorchClassifier",
@@ -416,15 +416,15 @@
416416
"loss=CrossEntropyLoss(), optimizer=Adam",
417417
"input_shape=(1, 28, 28), nb_classes=10, channel_index",
418418
"clip_values=array([0., 1.], dtype=float32",
419-
"preprocessing_defences=None, postprocessing_defences=None, preprocessing=[StandardisationMeanStdPyTorch(mean=0, std=1, apply_fit=True, apply_predict=True, device=cpu)]"
419+
"preprocessing_defences=None, postprocessing_defences=None, preprocessing=StandardisationMeanStdPyTorch(mean=0, std=1, apply_fit=True, apply_predict=True, device=cpu)"
420420
],
421421
"test_repr_keras": [
422422
"art.estimators.classification.keras.KerasClassifier",
423423
"use_logits=True",
424424
"channels_first=False",
425425
"clip_values=array([0., 1.], dtype=float32)",
426426
"preprocessing_defences=None",
427-
"preprocessing=[StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)]",
427+
"preprocessing=[(mean=0, std=1, apply_fit=True, apply_predict=True)",
428428
"input_layer=0, output_layer=0"
429429
],
430430
"test_repr_kerastf": [
@@ -433,14 +433,14 @@
433433
"channels_first=False",
434434
"clip_values=array([0., 1.], dtype=float32)",
435435
"preprocessing_defences=None",
436-
"preprocessing=[StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)]",
436+
"preprocessing=StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)",
437437
"input_layer=0, output_layer=0"
438438
],
439439
"test_repr_mxnet": [
440440
"art.estimators.classification.mxnet.MXClassifier",
441441
"input_shape=(1, 28, 28), nb_classes=10",
442442
"channels_first=True, clip_values=array([0., 1.], dtype=float32)",
443-
"defences=None, preprocessing=[StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)]"
443+
"defences=None, preprocessing=StandardisationMeanStd(mean=0, std=1, apply_fit=True, apply_predict=True)"
444444
],
445445
"test_predict_mxnet": [
446446
[

0 commit comments

Comments
 (0)