@@ -47,7 +47,7 @@ class BaseEstimator(ABC):
4747 "clip_values" ,
4848 "preprocessing_defences" ,
4949 "postprocessing_defences" ,
50- # "preprocessing", # preprocessing cannot be set with set_params
50+ "preprocessing" ,
5151 ]
5252
5353 def __init__ (
@@ -56,7 +56,7 @@ def __init__(
5656 clip_values : Optional ["CLIP_VALUES_TYPE" ],
5757 preprocessing_defences : Union ["Preprocessor" , List ["Preprocessor" ], None ] = None ,
5858 postprocessing_defences : Union ["Postprocessor" , List ["Postprocessor" ], None ] = None ,
59- preprocessing : "PREPROCESSING_TYPE" = (0 , 1 ),
59+ preprocessing : Union [ "PREPROCESSING_TYPE" , "Preprocessor" ] = (0 , 1 ),
6060 ):
6161 """
6262 Initialize a `BaseEstimator` object.
@@ -77,7 +77,7 @@ def __init__(
7777 self ._model = model
7878 self ._clip_values = clip_values
7979
80- self .preprocessing = preprocessing
80+ self .preprocessing = self . _set_preprocessing ( preprocessing )
8181 self .preprocessing_defences = self ._set_preprocessing_defences (preprocessing_defences )
8282 self .postprocessing_defences = self ._set_postprocessing_defences (postprocessing_defences )
8383 self .preprocessing_operations : List ["Preprocessor" ] = []
@@ -109,6 +109,23 @@ def _update_preprocessing_operations(self):
109109 else :
110110 raise ValueError ("Preprocessing argument not recognised." )
111111
112+ @staticmethod
113+ def _set_preprocessing (preprocessing : Union ["PREPROCESSING_TYPE" , "Preprocessor" ]) -> "Preprocessor" :
114+ from art .defences .preprocessor .preprocessor import Preprocessor
115+
116+ if preprocessing is None :
117+ from art .preprocessing .standardisation_mean_std .standardisation_mean_std import StandardisationMeanStd
118+
119+ return StandardisationMeanStd (mean = 0.0 , std = 1.0 )
120+ elif isinstance (preprocessing , tuple ):
121+ from art .preprocessing .standardisation_mean_std .standardisation_mean_std import StandardisationMeanStd
122+
123+ return StandardisationMeanStd (mean = preprocessing [0 ], std = preprocessing [1 ])
124+ elif isinstance (preprocessing , Preprocessor ):
125+ return preprocessing
126+ else :
127+ raise ValueError ("Preprocessing argument not recognised." )
128+
112129 @staticmethod
113130 def _set_preprocessing_defences (
114131 preprocessing_defences : Optional [Union ["Preprocessor" , List ["Preprocessor" ]]]
@@ -142,10 +159,12 @@ def set_params(self, **kwargs) -> None:
142159 if hasattr (BaseEstimator , key ) and isinstance (getattr (BaseEstimator , key ), property ):
143160 setattr (self , "_" + key , value )
144161 else :
145- if key == "preprocessing_defences" :
146- self ._set_preprocessing_defences (value )
162+ if key == "preprocessing" :
163+ setattr (self , key , self ._set_preprocessing (value ))
164+ elif key == "preprocessing_defences" :
165+ setattr (self , key , self ._set_preprocessing_defences (value ))
147166 elif key == "postprocessing_defences" :
148- self ._set_postprocessing_defences (value )
167+ setattr ( self , key , self ._set_postprocessing_defences (value ) )
149168 else :
150169 setattr (self , key , value )
151170 else :
0 commit comments