@@ -56,7 +56,7 @@ def __init__(
5656 clip_values : Optional ["CLIP_VALUES_TYPE" ],
5757 preprocessing_defences : Union ["Preprocessor" , List ["Preprocessor" ], None ] = None ,
5858 postprocessing_defences : Union ["Postprocessor" , List ["Postprocessor" ], None ] = None ,
59- preprocessing : "PREPROCESSING_TYPE" = (0 , 1 ),
59+ preprocessing : Union [ "PREPROCESSING_TYPE" , "Preprocessor" ] = (0 , 1 ),
6060 ):
6161 """
6262 Initialize a `BaseEstimator` object.
@@ -72,41 +72,81 @@ def __init__(
7272 used for data preprocessing. The first value will be subtracted from the input and the results will be
7373 divided by the second value.
7474 """
75- from art .defences .postprocessor .postprocessor import Postprocessor
7675 from art .defences .preprocessor .preprocessor import Preprocessor
7776
7877 self ._model = model
7978 self ._clip_values = clip_values
8079
81- self .preprocessing : List ["Preprocessor" ] = []
80+ self .preprocessing = self ._set_preprocessing (preprocessing )
81+ self .preprocessing_defences = self ._set_preprocessing_defences (preprocessing_defences )
82+ self .postprocessing_defences = self ._set_postprocessing_defences (postprocessing_defences )
83+ self .preprocessing_operations : List ["Preprocessor" ] = []
84+ self ._update_preprocessing_operations ()
85+ self ._check_params ()
86+
87+ def _update_preprocessing_operations (self ):
88+ from art .defences .preprocessor .preprocessor import Preprocessor
89+
90+ self .preprocessing_operations .clear ()
91+
92+ if self .preprocessing_defences is None :
93+ pass
94+ elif isinstance (self .preprocessing_defences , Preprocessor ):
95+ self .preprocessing_operations .append (self .preprocessing_defences )
96+ else :
97+ self .preprocessing_operations += self .preprocessing_defences
8298
83- if preprocessing_defences is None :
99+ if self . preprocessing is None :
84100 pass
85- elif isinstance (preprocessing_defences , Preprocessor ):
86- self .preprocessing .append (preprocessing_defences )
101+ elif isinstance (self .preprocessing , tuple ):
102+ from art .preprocessing .standardisation_mean_std .standardisation_mean_std import StandardisationMeanStd
103+
104+ self .preprocessing_operations .append (
105+ StandardisationMeanStd (mean = self .preprocessing [0 ], std = self .preprocessing [1 ])
106+ )
107+ elif isinstance (self .preprocessing , Preprocessor ):
108+ self .preprocessing_operations .append (self .preprocessing )
87109 else :
88- self . preprocessing += preprocessing_defences
110+ raise ValueError ( "Preprocessing argument not recognised." )
89111
90- self .preprocessing_defences = preprocessing_defences
112+ @staticmethod
113+ def _set_preprocessing (preprocessing : Union ["PREPROCESSING_TYPE" , "Preprocessor" ]) -> "Preprocessor" :
114+ from art .defences .preprocessor .preprocessor import Preprocessor
91115
92116 if preprocessing is None :
93- pass
117+ from art .preprocessing .standardisation_mean_std .standardisation_mean_std import StandardisationMeanStd
118+
119+ return StandardisationMeanStd (mean = 0.0 , std = 1.0 )
94120 elif isinstance (preprocessing , tuple ):
95121 from art .preprocessing .standardisation_mean_std .standardisation_mean_std import StandardisationMeanStd
96122
97- self . preprocessing . append ( StandardisationMeanStd (mean = preprocessing [0 ], std = preprocessing [1 ]) )
123+ return StandardisationMeanStd (mean = preprocessing [0 ], std = preprocessing [1 ])
98124 elif isinstance (preprocessing , Preprocessor ):
99- self . preprocessing . append ( preprocessing )
125+ return preprocessing
100126 else :
101- self . preprocessing += preprocessing
127+ raise ValueError ( "Preprocessing argument not recognised." )
102128
103- self .postprocessing_defences : Optional [List ["Postprocessor" ]]
104- if isinstance (postprocessing_defences , Postprocessor ):
105- self .postprocessing_defences = [postprocessing_defences ]
129+ @staticmethod
130+ def _set_preprocessing_defences (
131+ preprocessing_defences : Optional [Union ["Preprocessor" , List ["Preprocessor" ]]]
132+ ) -> Optional [List ["Preprocessor" ]]:
133+ from art .defences .preprocessor .preprocessor import Preprocessor
134+
135+ if isinstance (preprocessing_defences , Preprocessor ):
136+ return [preprocessing_defences ]
106137 else :
107- self . postprocessing_defences = postprocessing_defences
138+ return preprocessing_defences
108139
109- self ._check_params ()
140+ @staticmethod
141+ def _set_postprocessing_defences (
142+ postprocessing_defences : Optional [Union ["Postprocessor" , List ["Postprocessor" ]]]
143+ ) -> Optional [List ["Postprocessor" ]]:
144+ from art .defences .postprocessor .postprocessor import Postprocessor
145+
146+ if isinstance (postprocessing_defences , Postprocessor ):
147+ return [postprocessing_defences ]
148+ else :
149+ return postprocessing_defences
110150
111151 def set_params (self , ** kwargs ) -> None :
112152 """
@@ -119,9 +159,17 @@ def set_params(self, **kwargs) -> None:
119159 if hasattr (BaseEstimator , key ) and isinstance (getattr (BaseEstimator , key ), property ):
120160 setattr (self , "_" + key , value )
121161 else :
122- setattr (self , key , value )
162+ if key == "preprocessing" :
163+ setattr (self , key , self ._set_preprocessing (value ))
164+ elif key == "preprocessing_defences" :
165+ setattr (self , key , self ._set_preprocessing_defences (value ))
166+ elif key == "postprocessing_defences" :
167+ setattr (self , key , self ._set_postprocessing_defences (value ))
168+ else :
169+ setattr (self , key , value )
123170 else :
124- raise ValueError ("Unexpected parameter {} found in kwargs." .format (key ))
171+ raise ValueError ("Unexpected parameter `{}` found in kwargs." .format (key ))
172+ self ._update_preprocessing_operations ()
125173 self ._check_params ()
126174
127175 def get_params (self ) -> Dict [str , Any ]:
@@ -152,8 +200,8 @@ def _check_params(self) -> None:
152200 else :
153201 self ._clip_values = np .array (self ._clip_values , dtype = ART_NUMPY_DTYPE )
154202
155- if isinstance (self .preprocessing , list ):
156- for preprocess in self .preprocessing :
203+ if isinstance (self .preprocessing_operations , list ):
204+ for preprocess in self .preprocessing_operations :
157205 if not isinstance (preprocess , Preprocessor ):
158206 raise ValueError (
159207 "All preprocessing defences have to be instance of "
@@ -245,8 +293,8 @@ def _apply_preprocessing(self, x, y, fit: bool) -> Tuple[Any, Any]:
245293 :return: Tuple of `x` and `y` after applying the defences and standardisation.
246294 :rtype: Format as expected by the `model`
247295 """
248- if self .preprocessing :
249- for preprocess in self .preprocessing :
296+ if self .preprocessing_operations :
297+ for preprocess in self .preprocessing_operations :
250298 if fit :
251299 if preprocess .apply_fit :
252300 x , y = preprocess (x , y )
@@ -322,8 +370,8 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
322370 :return: Gradients after backward pass through normalization and preprocessing defences.
323371 :rtype: Format as expected by the `model`
324372 """
325- if self .preprocessing :
326- for preprocess in self .preprocessing [::- 1 ]:
373+ if self .preprocessing_operations :
374+ for preprocess in self .preprocessing_operations [::- 1 ]:
327375 if fit :
328376 if preprocess .apply_fit :
329377 gradients = preprocess .estimate_gradient (x , gradients )
0 commit comments