@@ -59,6 +59,7 @@ def __init__(
5959 attack_model_type : str = "nn" ,
6060 attack_model : Optional ["CLASSIFIER_TYPE" ] = None ,
6161 attack_feature : Union [int , slice ] = 0 ,
62+ is_regression : Optional [bool ] = False ,
6263 scale_range : Optional [slice ] = None ,
6364 prediction_normal_factor : float = 1 ,
6465 ):
@@ -72,11 +73,12 @@ def __init__(
7273 :param attack_feature: The index of the feature to be attacked or a slice representing multiple indexes in
7374 case of a one-hot encoded feature.
7475 case of a one-hot encoded feature.
76+ :param is_regression: Whether the model is a regression model. Default is False (classification).
7577 :param scale_range: If supplied, the class labels (both true and predicted) will be scaled to the given range.
76- Only applicable when `estimator ` is a regressor .
78+ Only applicable when `is_regression ` is True .
7779 :param prediction_normal_factor: If supplied, the class labels (both true and predicted) are multiplied by the
7880 factor when used as inputs to the attack-model. Only applicable when
79- `estimator ` is a regressor and if `scale_range` is not supplied.
81+ `is_regression ` is True and if `scale_range` is not supplied.
8082 """
8183 super ().__init__ (estimator = None , attack_feature = attack_feature )
8284
@@ -119,6 +121,7 @@ def __init__(
119121
120122 self .prediction_normal_factor = prediction_normal_factor
121123 self .scale_range = scale_range
124+ self .is_regression = is_regression
122125 self ._check_params ()
123126 self .attack_feature = get_feature_index (self .attack_feature )
124127
@@ -146,11 +149,14 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> None:
146149 raise ValueError ("None value detected." )
147150
148151 # create training set for attack model
149- if self .scale_range is not None :
150- normalized_labels = minmax_scale (y , feature_range = self .scale_range )
152+ if self .is_regression :
153+ if self .scale_range is not None :
154+ normalized_labels = minmax_scale (y , feature_range = self .scale_range )
155+ else :
156+ normalized_labels = y * self .prediction_normal_factor
157+ normalized_labels = normalized_labels .reshape (- 1 , 1 )
151158 else :
152- normalized_labels = y * self .prediction_normal_factor
153- normalized_labels = check_and_transform_label_format (normalized_labels , return_one_hot = True )
159+ normalized_labels = check_and_transform_label_format (y , return_one_hot = True )
154160 x_train = np .concatenate ((np .delete (x , self .attack_feature , 1 ), normalized_labels ), axis = 1 ).astype (np .float32 )
155161
156162 # train attack model
@@ -179,11 +185,14 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
179185 if values is not None :
180186 self ._values = values
181187
182- if self .scale_range is not None :
183- normalized_labels = minmax_scale (y , feature_range = self .scale_range )
188+ if self .is_regression :
189+ if self .scale_range is not None :
190+ normalized_labels = minmax_scale (y , feature_range = self .scale_range )
191+ else :
192+ normalized_labels = y * self .prediction_normal_factor
193+ normalized_labels = normalized_labels .reshape (- 1 , 1 )
184194 else :
185- normalized_labels = y * self .prediction_normal_factor
186- normalized_labels = check_and_transform_label_format (normalized_labels , return_one_hot = True )
195+ normalized_labels = check_and_transform_label_format (y , return_one_hot = True )
187196 x_test = np .concatenate ((x , normalized_labels ), axis = 1 ).astype (np .float32 )
188197
189198 predictions = self .attack_model .predict (x_test ).astype (np .float32 )
0 commit comments