We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3fcbad0 commit 5d0c872Copy full SHA for 5d0c872
tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py
@@ -34,7 +34,7 @@ def _get_adv_trainer():
34
if framework in ["tensorflow", "tensorflow2v1"]:
35
trainer = None
36
if framework == "pytorch":
37
- classifier, _ = image_dl_estimator()
+ classifier, _ = image_dl_estimator(from_logits=True)
38
attack = ProjectedGradientDescent(
39
classifier,
40
norm=np.inf,
@@ -121,7 +121,6 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
121
assert y_test_mnist.shape[-1] == 10
122
if label_format == "numerical":
123
y_test_mnist = np.argmax(y_test_mnist, axis=1)
124
- y_train_mnist = np.argmax(y_train_mnist, axis=1)
125
126
generator = image_data_generator()
127
0 commit comments