Skip to content

Commit fdc4532

Browse files
authored
Merge pull request #901 from Trusted-AI/development_issue_888
Update set_params in BaseEstimator
2 parents 57a5e56 + 50ef2b9 commit fdc4532

File tree

10 files changed

+163
-75
lines changed

10 files changed

+163
-75
lines changed

art/estimators/estimator.py

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
clip_values: Optional["CLIP_VALUES_TYPE"],
5757
preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
5858
postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
59-
preprocessing: "PREPROCESSING_TYPE" = (0, 1),
59+
preprocessing: Union["PREPROCESSING_TYPE", "Preprocessor"] = (0, 1),
6060
):
6161
"""
6262
Initialize a `BaseEstimator` object.
@@ -72,41 +72,81 @@ def __init__(
7272
used for data preprocessing. The first value will be subtracted from the input and the results will be
7373
divided by the second value.
7474
"""
75-
from art.defences.postprocessor.postprocessor import Postprocessor
7675
from art.defences.preprocessor.preprocessor import Preprocessor
7776

7877
self._model = model
7978
self._clip_values = clip_values
8079

81-
self.preprocessing: List["Preprocessor"] = []
80+
self.preprocessing = self._set_preprocessing(preprocessing)
81+
self.preprocessing_defences = self._set_preprocessing_defences(preprocessing_defences)
82+
self.postprocessing_defences = self._set_postprocessing_defences(postprocessing_defences)
83+
self.preprocessing_operations: List["Preprocessor"] = []
84+
self._update_preprocessing_operations()
85+
self._check_params()
86+
87+
def _update_preprocessing_operations(self):
88+
from art.defences.preprocessor.preprocessor import Preprocessor
89+
90+
self.preprocessing_operations.clear()
91+
92+
if self.preprocessing_defences is None:
93+
pass
94+
elif isinstance(self.preprocessing_defences, Preprocessor):
95+
self.preprocessing_operations.append(self.preprocessing_defences)
96+
else:
97+
self.preprocessing_operations += self.preprocessing_defences
8298

83-
if preprocessing_defences is None:
99+
if self.preprocessing is None:
84100
pass
85-
elif isinstance(preprocessing_defences, Preprocessor):
86-
self.preprocessing.append(preprocessing_defences)
101+
elif isinstance(self.preprocessing, tuple):
102+
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
103+
104+
self.preprocessing_operations.append(
105+
StandardisationMeanStd(mean=self.preprocessing[0], std=self.preprocessing[1])
106+
)
107+
elif isinstance(self.preprocessing, Preprocessor):
108+
self.preprocessing_operations.append(self.preprocessing)
87109
else:
88-
self.preprocessing += preprocessing_defences
110+
raise ValueError("Preprocessing argument not recognised.")
89111

90-
self.preprocessing_defences = preprocessing_defences
112+
@staticmethod
113+
def _set_preprocessing(preprocessing: Union["PREPROCESSING_TYPE", "Preprocessor"]) -> "Preprocessor":
114+
from art.defences.preprocessor.preprocessor import Preprocessor
91115

92116
if preprocessing is None:
93-
pass
117+
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
118+
119+
return StandardisationMeanStd(mean=0.0, std=1.0)
94120
elif isinstance(preprocessing, tuple):
95121
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
96122

97-
self.preprocessing.append(StandardisationMeanStd(mean=preprocessing[0], std=preprocessing[1]))
123+
return StandardisationMeanStd(mean=preprocessing[0], std=preprocessing[1])
98124
elif isinstance(preprocessing, Preprocessor):
99-
self.preprocessing.append(preprocessing)
125+
return preprocessing
100126
else:
101-
self.preprocessing += preprocessing
127+
raise ValueError("Preprocessing argument not recognised.")
102128

103-
self.postprocessing_defences: Optional[List["Postprocessor"]]
104-
if isinstance(postprocessing_defences, Postprocessor):
105-
self.postprocessing_defences = [postprocessing_defences]
129+
@staticmethod
130+
def _set_preprocessing_defences(
131+
preprocessing_defences: Optional[Union["Preprocessor", List["Preprocessor"]]]
132+
) -> Optional[List["Preprocessor"]]:
133+
from art.defences.preprocessor.preprocessor import Preprocessor
134+
135+
if isinstance(preprocessing_defences, Preprocessor):
136+
return [preprocessing_defences]
106137
else:
107-
self.postprocessing_defences = postprocessing_defences
138+
return preprocessing_defences
108139

109-
self._check_params()
140+
@staticmethod
141+
def _set_postprocessing_defences(
142+
postprocessing_defences: Optional[Union["Postprocessor", List["Postprocessor"]]]
143+
) -> Optional[List["Postprocessor"]]:
144+
from art.defences.postprocessor.postprocessor import Postprocessor
145+
146+
if isinstance(postprocessing_defences, Postprocessor):
147+
return [postprocessing_defences]
148+
else:
149+
return postprocessing_defences
110150

111151
def set_params(self, **kwargs) -> None:
112152
"""
@@ -119,9 +159,17 @@ def set_params(self, **kwargs) -> None:
119159
if hasattr(BaseEstimator, key) and isinstance(getattr(BaseEstimator, key), property):
120160
setattr(self, "_" + key, value)
121161
else:
122-
setattr(self, key, value)
162+
if key == "preprocessing":
163+
setattr(self, key, self._set_preprocessing(value))
164+
elif key == "preprocessing_defences":
165+
setattr(self, key, self._set_preprocessing_defences(value))
166+
elif key == "postprocessing_defences":
167+
setattr(self, key, self._set_postprocessing_defences(value))
168+
else:
169+
setattr(self, key, value)
123170
else:
124-
raise ValueError("Unexpected parameter {} found in kwargs.".format(key))
171+
raise ValueError("Unexpected parameter `{}` found in kwargs.".format(key))
172+
self._update_preprocessing_operations()
125173
self._check_params()
126174

127175
def get_params(self) -> Dict[str, Any]:
@@ -152,8 +200,8 @@ def _check_params(self) -> None:
152200
else:
153201
self._clip_values = np.array(self._clip_values, dtype=ART_NUMPY_DTYPE)
154202

155-
if isinstance(self.preprocessing, list):
156-
for preprocess in self.preprocessing:
203+
if isinstance(self.preprocessing_operations, list):
204+
for preprocess in self.preprocessing_operations:
157205
if not isinstance(preprocess, Preprocessor):
158206
raise ValueError(
159207
"All preprocessing defences have to be instance of "
@@ -245,8 +293,8 @@ def _apply_preprocessing(self, x, y, fit: bool) -> Tuple[Any, Any]:
245293
:return: Tuple of `x` and `y` after applying the defences and standardisation.
246294
:rtype: Format as expected by the `model`
247295
"""
248-
if self.preprocessing:
249-
for preprocess in self.preprocessing:
296+
if self.preprocessing_operations:
297+
for preprocess in self.preprocessing_operations:
250298
if fit:
251299
if preprocess.apply_fit:
252300
x, y = preprocess(x, y)
@@ -322,8 +370,8 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
322370
:return: Gradients after backward pass through normalization and preprocessing defences.
323371
:rtype: Format as expected by the `model`
324372
"""
325-
if self.preprocessing:
326-
for preprocess in self.preprocessing[::-1]:
373+
if self.preprocessing_operations:
374+
for preprocess in self.preprocessing_operations[::-1]:
327375
if fit:
328376
if preprocess.apply_fit:
329377
gradients = preprocess.estimate_gradient(x, gradients)

art/estimators/pytorch.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ def __init__(self, device_type: str = "gpu", **kwargs) -> None:
6565

6666
super().__init__(**kwargs)
6767

68-
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
69-
70-
self.all_framework_preprocessing = all([isinstance(p, PreprocessorPyTorch) for p in self.preprocessing])
71-
7268
# Set device
7369
if device_type == "cpu" or not torch.cuda.is_available():
7470
self._device = torch.device("cpu")
@@ -114,6 +110,23 @@ def loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
114110
"""
115111
raise NotImplementedError
116112

113+
def set_params(self, **kwargs) -> None:
114+
"""
115+
Take a dictionary of parameters and apply checks before setting them as attributes.
116+
117+
:param kwargs: A dictionary of attributes.
118+
"""
119+
super().set_params(**kwargs)
120+
self._check_params()
121+
122+
def _check_params(self) -> None:
123+
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
124+
125+
super()._check_params()
126+
self.all_framework_preprocessing = all(
127+
[isinstance(p, PreprocessorPyTorch) for p in self.preprocessing_operations]
128+
)
129+
117130
def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[Any, Any]:
118131
"""
119132
Apply all preprocessing defences of the estimator on the raw inputs `x` and `y`. This function is should
@@ -142,7 +155,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[A
142155
StandardisationMeanStdPyTorch,
143156
)
144157

145-
if not self.preprocessing:
158+
if not self.preprocessing_operations:
146159
return x, y
147160

148161
if isinstance(x, torch.Tensor):
@@ -158,7 +171,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[A
158171
y = torch.tensor(y, device=self._device)
159172

160173
def chain_processes(x, y):
161-
for preprocess in self.preprocessing:
174+
for preprocess in self.preprocessing_operations:
162175
if fit:
163176
if preprocess.apply_fit:
164177
x, y = preprocess.forward(x, y)
@@ -179,12 +192,12 @@ def chain_processes(x, y):
179192
if y is not None:
180193
y = y.cpu().numpy()
181194

182-
elif len(self.preprocessing) == 1 or (
183-
len(self.preprocessing) == 2
184-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
195+
elif len(self.preprocessing_operations) == 1 or (
196+
len(self.preprocessing_operations) == 2
197+
and isinstance(self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
185198
):
186199
# Compatible with non-PyTorch defences if no chaining.
187-
for preprocess in self.preprocessing:
200+
for preprocess in self.preprocessing_operations:
188201
x, y = preprocess(x, y)
189202

190203
else:
@@ -218,7 +231,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
218231
StandardisationMeanStdPyTorch,
219232
)
220233

221-
if not self.preprocessing:
234+
if not self.preprocessing_operations:
222235
return gradients
223236

224237
if isinstance(x, torch.Tensor):
@@ -232,7 +245,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
232245
gradients = torch.tensor(gradients, device=self._device)
233246
x_orig = x
234247

235-
for preprocess in self.preprocessing:
248+
for preprocess in self.preprocessing_operations:
236249
if fit:
237250
if preprocess.apply_fit:
238251
x = preprocess.estimate_forward(x)
@@ -249,13 +262,18 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
249262
"The input shape is {} while the gradient shape is {}".format(x.shape, gradients.shape)
250263
)
251264

252-
elif len(self.preprocessing) == 1 or (
253-
len(self.preprocessing) == 2
254-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
265+
elif len(self.preprocessing_operations) == 1 or (
266+
len(self.preprocessing_operations) == 2
267+
and isinstance(self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdPyTorch))
255268
):
256269
# Compatible with non-PyTorch defences if no chaining.
257-
defence = self.preprocessing[0]
258-
gradients = defence.estimate_gradient(x, gradients)
270+
for preprocess in self.preprocessing_operations[::-1]:
271+
if fit:
272+
if preprocess.apply_fit:
273+
gradients = preprocess.estimate_gradient(x, gradients)
274+
else:
275+
if preprocess.apply_predict:
276+
gradients = preprocess.estimate_gradient(x, gradients)
259277

260278
else:
261279
raise NotImplementedError("The current combination of preprocessing types is not supported.")

art/estimators/tensorflow.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,6 @@ def __init__(self, **kwargs):
118118

119119
super().__init__(**kwargs)
120120

121-
from art.defences.preprocessor.preprocessor import PreprocessorTensorFlowV2
122-
123-
self.all_framework_preprocessing = all([isinstance(p, PreprocessorTensorFlowV2) for p in self.preprocessing])
124-
125121
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs):
126122
"""
127123
Perform prediction of the neural network for samples `x`.
@@ -160,6 +156,23 @@ def loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
160156
"""
161157
raise NotImplementedError
162158

159+
def set_params(self, **kwargs) -> None:
160+
"""
161+
Take a dictionary of parameters and apply checks before setting them as attributes.
162+
163+
:param kwargs: A dictionary of attributes.
164+
"""
165+
super().set_params(**kwargs)
166+
self._check_params()
167+
168+
def _check_params(self) -> None:
169+
from art.defences.preprocessor.preprocessor import PreprocessorTensorFlowV2
170+
171+
super()._check_params()
172+
self.all_framework_preprocessing = all(
173+
[isinstance(p, PreprocessorTensorFlowV2) for p in self.preprocessing_operations]
174+
)
175+
163176
def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
164177
"""
165178
Apply all preprocessing defences of the estimator on the raw inputs `x` and `y`. This function is should
@@ -186,7 +199,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
186199
StandardisationMeanStdTensorFlowV2,
187200
)
188201

189-
if not self.preprocessing:
202+
if not self.preprocessing_operations:
190203
return x, y
191204

192205
if isinstance(x, tf.Tensor):
@@ -201,7 +214,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
201214
if y is not None:
202215
y = tf.convert_to_tensor(y)
203216

204-
for preprocess in self.preprocessing:
217+
for preprocess in self.preprocessing_operations:
205218
if fit:
206219
if preprocess.apply_fit:
207220
x, y = preprocess.forward(x, y)
@@ -215,12 +228,14 @@ def _apply_preprocessing(self, x, y, fit: bool = False) -> Tuple[Any, Any]:
215228
if y is not None:
216229
y = y.numpy()
217230

218-
elif len(self.preprocessing) == 1 or (
219-
len(self.preprocessing) == 2
220-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2))
231+
elif len(self.preprocessing_operations) == 1 or (
232+
len(self.preprocessing_operations) == 2
233+
and isinstance(
234+
self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2)
235+
)
221236
):
222237
# Compatible with non-TensorFlow defences if no chaining.
223-
for preprocess in self.preprocessing:
238+
for preprocess in self.preprocessing_operations:
224239
x, y = preprocess(x, y)
225240

226241
else:
@@ -254,7 +269,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
254269
StandardisationMeanStdTensorFlowV2,
255270
)
256271

257-
if not self.preprocessing:
272+
if not self.preprocessing_operations:
258273
return gradients
259274

260275
if isinstance(x, tf.Tensor):
@@ -270,7 +285,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
270285
gradients = tf.convert_to_tensor(gradients, dtype=config.ART_NUMPY_DTYPE)
271286
x_orig = x
272287

273-
for preprocess in self.preprocessing:
288+
for preprocess in self.preprocessing_operations:
274289
if fit:
275290
if preprocess.apply_fit:
276291
x = preprocess.estimate_forward(x)
@@ -287,13 +302,20 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
287302
"The input shape is {} while the gradient shape is {}".format(x.shape, gradients.shape)
288303
)
289304

290-
elif len(self.preprocessing) == 1 or (
291-
len(self.preprocessing) == 2
292-
and isinstance(self.preprocessing[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2))
305+
elif len(self.preprocessing_operations) == 1 or (
306+
len(self.preprocessing_operations) == 2
307+
and isinstance(
308+
self.preprocessing_operations[-1], (StandardisationMeanStd, StandardisationMeanStdTensorFlowV2)
309+
)
293310
):
294311
# Compatible with non-TensorFlow defences if no chaining.
295-
defence = self.preprocessing[0]
296-
gradients = defence.estimate_gradient(x, gradients)
312+
for preprocess in self.preprocessing_operations[::-1]:
313+
if fit:
314+
if preprocess.apply_fit:
315+
gradients = preprocess.estimate_gradient(x, gradients)
316+
else:
317+
if preprocess.apply_predict:
318+
gradients = preprocess.estimate_gradient(x, gradients)
297319

298320
else:
299321
raise NotImplementedError("The current combination of preprocessing types is not supported.")

0 commit comments

Comments
 (0)