Skip to content

Commit 5319d95

Browse files
author
Beat Buesser
committed
Updates for BaseEstimator.set_params
Signed-off-by: Beat Buesser <[email protected]>
1 parent d1f97b7 commit 5319d95

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

art/estimators/estimator.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class BaseEstimator(ABC):
4747
"clip_values",
4848
"preprocessing_defences",
4949
"postprocessing_defences",
50-
# "preprocessing", # preprocessing cannot be set with set_params
50+
"preprocessing",
5151
]
5252

5353
def __init__(
@@ -56,7 +56,7 @@ def __init__(
5656
clip_values: Optional["CLIP_VALUES_TYPE"],
5757
preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
5858
postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
59-
preprocessing: "PREPROCESSING_TYPE" = (0, 1),
59+
preprocessing: Union["PREPROCESSING_TYPE", "Preprocessor"] = (0, 1),
6060
):
6161
"""
6262
Initialize a `BaseEstimator` object.
@@ -77,7 +77,7 @@ def __init__(
7777
self._model = model
7878
self._clip_values = clip_values
7979

80-
self.preprocessing = preprocessing
80+
self.preprocessing = self._set_preprocessing(preprocessing)
8181
self.preprocessing_defences = self._set_preprocessing_defences(preprocessing_defences)
8282
self.postprocessing_defences = self._set_postprocessing_defences(postprocessing_defences)
8383
self.preprocessing_operations: List["Preprocessor"] = []
@@ -109,6 +109,23 @@ def _update_preprocessing_operations(self):
109109
else:
110110
raise ValueError("Preprocessing argument not recognised.")
111111

112+
@staticmethod
113+
def _set_preprocessing(preprocessing: Union["PREPROCESSING_TYPE", "Preprocessor"]) -> "Preprocessor":
114+
from art.defences.preprocessor.preprocessor import Preprocessor
115+
116+
if preprocessing is None:
117+
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
118+
119+
return StandardisationMeanStd(mean=0.0, std=1.0)
120+
elif isinstance(preprocessing, tuple):
121+
from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
122+
123+
return StandardisationMeanStd(mean=preprocessing[0], std=preprocessing[1])
124+
elif isinstance(preprocessing, Preprocessor):
125+
return preprocessing
126+
else:
127+
raise ValueError("Preprocessing argument not recognised.")
128+
112129
@staticmethod
113130
def _set_preprocessing_defences(
114131
preprocessing_defences: Optional[Union["Preprocessor", List["Preprocessor"]]]
@@ -142,10 +159,12 @@ def set_params(self, **kwargs) -> None:
142159
if hasattr(BaseEstimator, key) and isinstance(getattr(BaseEstimator, key), property):
143160
setattr(self, "_" + key, value)
144161
else:
145-
if key == "preprocessing_defences":
146-
self._set_preprocessing_defences(value)
162+
if key == "preprocessing":
163+
setattr(self, key, self._set_preprocessing(value))
164+
elif key == "preprocessing_defences":
165+
setattr(self, key, self._set_preprocessing_defences(value))
147166
elif key == "postprocessing_defences":
148-
self._set_postprocessing_defences(value)
167+
setattr(self, key, self._set_postprocessing_defences(value))
149168
else:
150169
setattr(self, key, value)
151170
else:

0 commit comments

Comments
 (0)