Skip to content

Commit 52c240a

Browse files
authored
Merge pull request #2231 from GiulioZizzo/update_trades
Adding label check to trades adversarial trainer
2 parents c63d5d5 + d43f473 commit 52c240a

File tree

2 files changed

+54
-14
lines changed

2 files changed

+54
-14
lines changed

art/defences/trainer/adversarial_trainer_trades_pytorch.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from art.estimators.classification.pytorch import PyTorchClassifier
3434
from art.data_generators import DataGenerator
3535
from art.attacks.attack import EvasionAttack
36+
from art.utils import check_and_transform_label_format
3637

3738
if TYPE_CHECKING:
3839
import torch
@@ -97,6 +98,15 @@ def fit(
9798
ind = np.arange(len(x))
9899

99100
logger.info("Adversarial Training TRADES")
101+
y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)
102+
103+
if validation_data is not None:
104+
(x_test, y_test) = validation_data
105+
y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)
106+
107+
x_preprocessed_test, y_preprocessed_test = self._classifier._apply_preprocessing( # pylint: disable=W0212
108+
x_test, y_test, fit=True
109+
)
100110

101111
for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
102112
# Shuffle the examples
@@ -107,7 +117,6 @@ def fit(
107117
train_n = 0.0
108118

109119
for batch_id in range(nb_batches):
110-
111120
# Create batch data
112121
x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
113122
y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]
@@ -125,9 +134,9 @@ def fit(
125134

126135
# compute accuracy
127136
if validation_data is not None:
128-
(x_test, y_test) = validation_data
129-
output = np.argmax(self.predict(x_test), axis=1)
130-
nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
137+
output = np.argmax(self.predict(x_preprocessed_test), axis=1)
138+
nb_correct_pred = np.sum(output == np.argmax(y_preprocessed_test, axis=1))
139+
131140
logger.info(
132141
"epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
133142
i_epoch,
@@ -188,7 +197,6 @@ def fit_generator(
188197
train_n = 0.0
189198

190199
for batch_id in range(nb_batches): # pylint: disable=W0612
191-
192200
# Create batch data
193201
x_batch, y_batch = generator.get_batch()
194202
x_batch = x_batch.copy()
@@ -232,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
232240
x_batch_pert = self._attack.generate(x_batch, y=y_batch)
233241

234242
# Apply preprocessing
243+
y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)
244+
235245
x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing( # pylint: disable=W0212
236246
x_batch, y_batch, fit=True
237247
)

tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_adv_trainer():
3434
if framework in ["tensorflow", "tensorflow2v1"]:
3535
trainer = None
3636
if framework == "pytorch":
37-
classifier, _ = image_dl_estimator()
37+
classifier, _ = image_dl_estimator(from_logits=True)
3838
attack = ProjectedGradientDescent(
3939
classifier,
4040
norm=np.inf,
@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
6363
yield x_train_mnist[:n_train], y_train_mnist[:n_train], x_test_mnist[:n_test], y_test_mnist[:n_test]
6464

6565

66-
@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
67-
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset):
66+
@pytest.mark.only_with_platform("pytorch")
67+
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
68+
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
6869
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
6970
x_test_mnist_original = x_test_mnist.copy()
7071

72+
if label_format == "one_hot":
73+
assert y_train_mnist.shape[-1] == 10
74+
assert y_test_mnist.shape[-1] == 10
75+
if label_format == "numerical":
76+
y_test_mnist = np.argmax(y_test_mnist, axis=1)
77+
y_train_mnist = np.argmax(y_train_mnist, axis=1)
78+
7179
trainer = get_adv_trainer()
7280
if trainer is None:
7381
logging.warning("Couldn't perform this test because no trainer is defined for this framework configuration")
7482
return
7583

7684
predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
77-
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
85+
86+
if label_format == "one_hot":
87+
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
88+
else:
89+
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
7890

7991
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
8092
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
81-
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
93+
94+
if label_format == "one_hot":
95+
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
96+
else:
97+
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]
8298

8399
np.testing.assert_array_almost_equal(
84100
float(np.mean(x_test_mnist_original - x_test_mnist)),
@@ -92,13 +108,20 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
92108
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
93109

94110

95-
@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
111+
@pytest.mark.only_with_platform("pytorch")
112+
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
96113
def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
97-
get_adv_trainer, fix_get_mnist_subset, image_data_generator
114+
get_adv_trainer, fix_get_mnist_subset, image_data_generator, label_format
98115
):
99116
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
100117
x_test_mnist_original = x_test_mnist.copy()
101118

119+
if label_format == "one_hot":
120+
assert y_train_mnist.shape[-1] == 10
121+
assert y_test_mnist.shape[-1] == 10
122+
if label_format == "numerical":
123+
y_test_mnist = np.argmax(y_test_mnist, axis=1)
124+
102125
generator = image_data_generator()
103126

104127
trainer = get_adv_trainer()
@@ -107,11 +130,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
107130
return
108131

109132
predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
110-
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
133+
if label_format == "one_hot":
134+
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
135+
else:
136+
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
111137

112138
trainer.fit_generator(generator=generator, nb_epochs=20)
113139
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
114-
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
140+
141+
if label_format == "one_hot":
142+
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
143+
else:
144+
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]
115145

116146
np.testing.assert_array_almost_equal(
117147
float(np.mean(x_test_mnist_original - x_test_mnist)),

0 commit comments

Comments
 (0)