3333from art .estimators .classification .pytorch import PyTorchClassifier
3434from art .data_generators import DataGenerator
3535from art .attacks .attack import EvasionAttack
36+ from art .utils import check_and_transform_label_format
3637
3738if TYPE_CHECKING :
3839 import torch
@@ -97,6 +98,15 @@ def fit(
9798 ind = np .arange (len (x ))
9899
99100 logger .info ("Adversarial Training TRADES" )
101+ y = check_and_transform_label_format (y , nb_classes = self .classifier .nb_classes )
102+
103+ if validation_data is not None :
104+ (x_test , y_test ) = validation_data
105+ y_test = check_and_transform_label_format (y_test , nb_classes = self .classifier .nb_classes )
106+
107+ x_preprocessed_test , y_preprocessed_test = self ._classifier ._apply_preprocessing ( # pylint: disable=W0212
108+ x_test , y_test , fit = True
109+ )
100110
101111 for i_epoch in trange (nb_epochs , desc = "Adversarial Training TRADES - Epochs" ):
102112 # Shuffle the examples
@@ -107,7 +117,6 @@ def fit(
107117 train_n = 0.0
108118
109119 for batch_id in range (nb_batches ):
110-
111120 # Create batch data
112121 x_batch = x [ind [batch_id * batch_size : min ((batch_id + 1 ) * batch_size , x .shape [0 ])]].copy ()
113122 y_batch = y [ind [batch_id * batch_size : min ((batch_id + 1 ) * batch_size , x .shape [0 ])]]
@@ -125,13 +134,8 @@ def fit(
125134
126135 # compute accuracy
127136 if validation_data is not None :
128- (x_test , y_test ) = validation_data
129-
130- output = np .argmax (self .predict (x_test ), axis = 1 )
131- if y_test .ndim > 1 :
132- nb_correct_pred = np .sum (output == np .argmax (y_test , axis = 1 ))
133- else :
134- nb_correct_pred = np .sum (output == y_test )
137+ output = np .argmax (self .predict (x_preprocessed_test ), axis = 1 )
138+ nb_correct_pred = np .sum (output == np .argmax (y_preprocessed_test , axis = 1 ))
135139
136140 logger .info (
137141 "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f" ,
@@ -193,7 +197,6 @@ def fit_generator(
193197 train_n = 0.0
194198
195199 for batch_id in range (nb_batches ): # pylint: disable=W0612
196-
197200 # Create batch data
198201 x_batch , y_batch = generator .get_batch ()
199202 x_batch = x_batch .copy ()
@@ -237,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
237240 x_batch_pert = self ._attack .generate (x_batch , y = y_batch )
238241
239242 # Apply preprocessing
243+ y_batch = check_and_transform_label_format (y_batch , nb_classes = self .classifier .nb_classes )
244+
240245 x_preprocessed , y_preprocessed = self ._classifier ._apply_preprocessing ( # pylint: disable=W0212
241246 x_batch , y_batch , fit = True
242247 )
@@ -245,7 +250,7 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
245250 )
246251
247252 # Check label shape
248- if self ._classifier ._reduce_labels and y_preprocessed . ndim > 1 : # pylint: disable=W0212
253+ if self ._classifier ._reduce_labels : # pylint: disable=W0212
249254 y_preprocessed = np .argmax (y_preprocessed , axis = 1 )
250255
251256 i_batch = torch .from_numpy (x_preprocessed ).to (self ._classifier ._device ) # pylint: disable=W0212
0 commit comments