Skip to content

Commit 5d0c872

Browse files
committed
add from_logits=True to begin addressing issue #2227
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 3fcbad0 commit 5d0c872

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_adv_trainer():
3434
if framework in ["tensorflow", "tensorflow2v1"]:
3535
trainer = None
3636
if framework == "pytorch":
37-
classifier, _ = image_dl_estimator()
37+
classifier, _ = image_dl_estimator(from_logits=True)
3838
attack = ProjectedGradientDescent(
3939
classifier,
4040
norm=np.inf,
@@ -121,7 +121,6 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
121121
assert y_test_mnist.shape[-1] == 10
122122
if label_format == "numerical":
123123
y_test_mnist = np.argmax(y_test_mnist, axis=1)
124-
y_train_mnist = np.argmax(y_train_mnist, axis=1)
125124

126125
generator = image_data_generator()
127126

0 commit comments

Comments
 (0)