3232from art .estimators .classification .classifier import ClassifierMixin
3333from art .attacks .attack import AttributeInferenceAttack
3434from art .estimators .regression import RegressorMixin
35- from art .utils import check_and_transform_label_format , float_to_categorical , floats_to_one_hot , get_feature_values
35+ from art .utils import (
36+ check_and_transform_label_format ,
37+ float_to_categorical ,
38+ floats_to_one_hot ,
39+ get_feature_values ,
40+ get_feature_index ,
41+ )
3642
3743if TYPE_CHECKING :
3844 from art .utils import CLASSIFIER_TYPE , REGRESSOR_TYPE
@@ -83,10 +89,6 @@ def __init__(
8389 `estimator` is a regressor and if `scale_range` is not supplied.
8490 """
8591 super ().__init__ (estimator = estimator , attack_feature = attack_feature )
86- if isinstance (self .attack_feature , int ):
87- self .single_index_feature = True
88- else :
89- self .single_index_feature = False
9092
9193 self ._values : Optional [list ] = None
9294 self ._attack_model_type = attack_model_type
@@ -131,6 +133,7 @@ def __init__(
131133 self .scale_range = scale_range
132134
133135 self ._check_params ()
136+ self .attack_feature = get_feature_index (self .attack_feature )
134137
135138 def fit (self , x : np .ndarray , y : Optional [np .ndarray ] = None ) -> None :
136139 """
@@ -144,7 +147,7 @@ def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
144147 if self .estimator .input_shape is not None :
145148 if self .estimator .input_shape [0 ] != x .shape [1 ]:
146149 raise ValueError ("Shape of x does not match input_shape of model" )
147- if self . single_index_feature and isinstance (self .attack_feature , int ) and self .attack_feature >= x .shape [1 ]:
150+ if isinstance (self .attack_feature , int ) and self .attack_feature >= x .shape [1 ]:
148151 raise ValueError ("`attack_feature` must be a valid index to a feature in x" )
149152
150153 # get model's predictions for x
@@ -162,8 +165,8 @@ def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
162165
163166 # get vector of attacked feature
164167 y_attack = x [:, self .attack_feature ]
165- self ._values = get_feature_values (y_attack , self .single_index_feature )
166- if self .single_index_feature :
168+ self ._values = get_feature_values (y_attack , isinstance ( self .attack_feature , int ) )
169+ if isinstance ( self .attack_feature , int ) :
167170 y_one_hot = float_to_categorical (y_attack )
168171 else :
169172 y_one_hot = floats_to_one_hot (y_attack )
@@ -210,7 +213,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
210213 if pred .shape [0 ] != x .shape [0 ]:
211214 raise ValueError ("Number of rows in x and y do not match" )
212215 if self .estimator .input_shape is not None :
213- if self .single_index_feature and self .estimator .input_shape [0 ] != x .shape [1 ] + 1 :
216+ if isinstance ( self .attack_feature , int ) and self .estimator .input_shape [0 ] != x .shape [1 ] + 1 :
214217 raise ValueError ("Number of features in x + 1 does not match input_shape of model" )
215218
216219 if RegressorMixin in type (self .estimator ).__mro__ :
@@ -234,7 +237,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
234237 predictions = self .attack_model .predict (x_test ).astype (np .float32 )
235238
236239 if self ._values is not None :
237- if self .single_index_feature :
240+ if isinstance ( self .attack_feature , int ) :
238241 predictions = np .array ([self ._values [np .argmax (arr )] for arr in predictions ])
239242 else :
240243 i = 0
0 commit comments