Skip to content

Commit 3fcbad0

Browse files
committed
adding label check to trades adversarial trainer
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 82f8fa2 commit 3fcbad0

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

art/defences/trainer/adversarial_trainer_trades_pytorch.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,13 @@ def fit(
126126
# compute accuracy
127127
if validation_data is not None:
128128
(x_test, y_test) = validation_data
129+
129130
output = np.argmax(self.predict(x_test), axis=1)
130-
nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
131+
if y_test.ndim > 1:
132+
nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
133+
else:
134+
nb_correct_pred = np.sum(output == y_test)
135+
131136
logger.info(
132137
"epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
133138
i_epoch,
@@ -240,7 +245,7 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
240245
)
241246

242247
# Check label shape
243-
if self._classifier._reduce_labels: # pylint: disable=W0212
248+
if self._classifier._reduce_labels and y_preprocessed.ndim > 1: # pylint: disable=W0212
244249
y_preprocessed = np.argmax(y_preprocessed, axis=1)
245250

246251
i_batch = torch.from_numpy(x_preprocessed).to(self._classifier._device) # pylint: disable=W0212

tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
6363
yield x_train_mnist[:n_train], y_train_mnist[:n_train], x_test_mnist[:n_test], y_test_mnist[:n_test]
6464

6565

66-
@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
67-
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset):
66+
@pytest.mark.only_with_platform("pytorch")
67+
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
68+
def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
6869
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
6970
x_test_mnist_original = x_test_mnist.copy()
7071

72+
if label_format == "one_hot":
73+
assert y_train_mnist.shape[-1] == 10
74+
assert y_test_mnist.shape[-1] == 10
75+
if label_format == "numerical":
76+
y_test_mnist = np.argmax(y_test_mnist, axis=1)
77+
y_train_mnist = np.argmax(y_train_mnist, axis=1)
78+
7179
trainer = get_adv_trainer()
7280
if trainer is None:
7381
logging.warning("Couldn't perform this test because no trainer is defined for this framework configuration")
7482
return
7583

7684
predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
77-
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
85+
86+
if label_format == "one_hot":
87+
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
88+
else:
89+
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
7890

7991
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
8092
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
81-
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
93+
94+
if label_format == "one_hot":
95+
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
96+
else:
97+
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]
8298

8399
np.testing.assert_array_almost_equal(
84100
float(np.mean(x_test_mnist_original - x_test_mnist)),
@@ -92,13 +108,21 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
92108
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
93109

94110

95-
@pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
111+
@pytest.mark.only_with_platform("pytorch")
112+
@pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
96113
def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
97-
get_adv_trainer, fix_get_mnist_subset, image_data_generator
114+
get_adv_trainer, fix_get_mnist_subset, image_data_generator, label_format
98115
):
99116
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
100117
x_test_mnist_original = x_test_mnist.copy()
101118

119+
if label_format == "one_hot":
120+
assert y_train_mnist.shape[-1] == 10
121+
assert y_test_mnist.shape[-1] == 10
122+
if label_format == "numerical":
123+
y_test_mnist = np.argmax(y_test_mnist, axis=1)
124+
y_train_mnist = np.argmax(y_train_mnist, axis=1)
125+
102126
generator = image_data_generator()
103127

104128
trainer = get_adv_trainer()
@@ -107,11 +131,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
107131
return
108132

109133
predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
110-
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
134+
if label_format == "one_hot":
135+
accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
136+
else:
137+
accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]
111138

112139
trainer.fit_generator(generator=generator, nb_epochs=20)
113140
predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
114-
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
141+
142+
if label_format == "one_hot":
143+
accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
144+
else:
145+
accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]
115146

116147
np.testing.assert_array_almost_equal(
117148
float(np.mean(x_test_mnist_original - x_test_mnist)),

0 commit comments

Comments
 (0)