
Commit 2338a08
use check_and_transform_label_format to ensure a consistent label format
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 5d0c872

File tree

1 file changed: +15 −10 lines changed

art/defences/trainer/adversarial_trainer_trades_pytorch.py

Lines changed: 15 additions & 10 deletions
@@ -33,6 +33,7 @@
 from art.estimators.classification.pytorch import PyTorchClassifier
 from art.data_generators import DataGenerator
 from art.attacks.attack import EvasionAttack
+from art.utils import check_and_transform_label_format

 if TYPE_CHECKING:
     import torch
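
For context, check_and_transform_label_format from art.utils normalizes labels to a single format, one-hot encoded by default. A minimal sketch of the behavior this commit relies on, with shapes chosen purely for illustration:

import numpy as np
from art.utils import check_and_transform_label_format

# Index labels of shape (N,) are expanded to one-hot of shape (N, nb_classes);
# labels that are already one-hot pass through unchanged.
y_index = np.array([0, 2, 1])
y_onehot = check_and_transform_label_format(y_index, nb_classes=3)
print(y_onehot.shape)  # (3, 3)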
@@ -97,6 +98,15 @@ def fit(
         ind = np.arange(len(x))

         logger.info("Adversarial Training TRADES")
+        y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)
+
+        if validation_data is not None:
+            (x_test, y_test) = validation_data
+            y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)
+
+            x_preprocessed_test, y_preprocessed_test = self._classifier._apply_preprocessing(  # pylint: disable=W0212
+                x_test, y_test, fit=True
+            )

         for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
             # Shuffle the examples
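
The hunk above hoists validation handling out of the epoch loop: labels are normalized and preprocessing is applied once, and the cached arrays are reused at every epoch. A minimal sketch of the hoisting pattern, using a hypothetical stand-in preprocess function rather than the classifier's real one:

import numpy as np

def preprocess(x, y, nb_classes=3):
    # Hypothetical stand-in for _apply_preprocessing: scale inputs, one-hot labels.
    return x / 255.0, np.eye(nb_classes)[y]

x_val = np.random.randint(0, 256, size=(8, 4)).astype(np.float32)
y_val = np.random.randint(0, 3, size=8)

# Preprocess once, outside the training loop ...
x_val_prep, y_val_prep = preprocess(x_val, y_val)
for i_epoch in range(3):
    # ... and reuse the cached arrays at every validation step.
    acc = np.mean(np.argmax(y_val_prep, axis=1) == y_val)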
@@ -107,7 +117,6 @@
             train_n = 0.0

             for batch_id in range(nb_batches):
-
                 # Create batch data
                 x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
                 y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]
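
As an aside, the batch slicing above works on a shuffled index array: each batch takes a contiguous slice of indices, with min() clamping the final, possibly short, batch to the dataset size. A self-contained sketch of that pattern:

import numpy as np

x = np.arange(10).reshape(10, 1)
ind = np.arange(len(x))
np.random.shuffle(ind)  # the trainer reshuffles once per epoch

batch_size = 4
nb_batches = int(np.ceil(len(x) / batch_size))
for batch_id in range(nb_batches):
    # The min() clamp keeps the last batch in bounds (here it has 2 samples).
    x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()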
@@ -125,13 +134,8 @@

             # compute accuracy
             if validation_data is not None:
-                (x_test, y_test) = validation_data
-
-                output = np.argmax(self.predict(x_test), axis=1)
-                if y_test.ndim > 1:
-                    nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
-                else:
-                    nb_correct_pred = np.sum(output == y_test)
+                output = np.argmax(self.predict(x_preprocessed_test), axis=1)
+                nb_correct_pred = np.sum(output == np.argmax(y_preprocessed_test, axis=1))

                 logger.info(
                     "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
@@ -193,7 +197,6 @@ def fit_generator(
             train_n = 0.0

             for batch_id in range(nb_batches):  # pylint: disable=W0612
-
                 # Create batch data
                 x_batch, y_batch = generator.get_batch()
                 x_batch = x_batch.copy()
@@ -237,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
         x_batch_pert = self._attack.generate(x_batch, y=y_batch)

         # Apply preprocessing
+        y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)
+
         x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing(  # pylint: disable=W0212
             x_batch, y_batch, fit=True
         )
@@ -245,7 +250,7 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
         )

         # Check label shape
-        if self._classifier._reduce_labels and y_preprocessed.ndim > 1:  # pylint: disable=W0212
+        if self._classifier._reduce_labels:  # pylint: disable=W0212
             y_preprocessed = np.argmax(y_preprocessed, axis=1)

         i_batch = torch.from_numpy(x_preprocessed).to(self._classifier._device)  # pylint: disable=W0212
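
The same reasoning applies in _batch_process: because y_batch is normalized to one-hot before preprocessing, the y_preprocessed.ndim > 1 guard is redundant and label reduction can argmax unconditionally. A minimal sketch:

import numpy as np

# After normalization the labels are always 2-D one-hot, so argmax is safe.
y_preprocessed = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
y_reduced = np.argmax(y_preprocessed, axis=1)
print(y_reduced)  # [1 0]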
