Skip to content

Commit 78890a4

Browse files
committed
fixed discrepancies
1 parent bce965c commit 78890a4

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

art/defences/adversarial_trainer.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
7575

7676
for batch_id in range(nb_batches):
7777
# Create batch data
78-
x_batch = x[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]]
78+
x_batch = x[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]].copy()
7979
y_batch = y[ind[batch_id * batch_size:min((batch_id + 1) * batch_size, x.shape[0])]]
8080

8181
# Choose indices to replace with adversarial samples
@@ -84,12 +84,7 @@ def fit(self, x, y, batch_size=128, nb_epochs=20):
8484

8585
# If source and target models are the same, craft fresh adversarial samples
8686
if attack.classifier == self.classifier:
87-
labels_batch = np.argmax(y_batch, axis=1)
88-
x_adv = attack.generate(x_batch[adv_ids])
89-
y_adv = np.argmax(attack.classifier.predict(x_adv), axis=1)
90-
selected = np.array(y_adv != labels_batch[adv_ids])
91-
92-
x_batch[adv_ids][selected] = x_adv[selected]
87+
x_batch[adv_ids] = attack.generate(x_batch[adv_ids])
9388

9489
# Otherwise, use precomputed adversarial samples
9590
else:

0 commit comments

Comments
 (0)