Skip to content

Commit c67d59f

Browse files
authored
Merge pull request #1769 from monshri/dev_sleeper_agent_tf
Implement Sleeper Agent Poisoning Attack in TensorFlow
2 parents 8641de9 + 897c4da commit c67d59f

File tree

5 files changed

+1430
-2007
lines changed

5 files changed

+1430
-2007
lines changed

art/attacks/poisoning/gradient_matching_attack.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
self.max_epochs = max_epochs
9898
self.batch_size = batch_size
9999
self.clip_values = clip_values
100+
self.initial_epoch = 0
100101

101102
if verbose is True:
102103
verbose = 1
@@ -175,7 +176,7 @@ def _weight_grad(classifier: TensorFlowV2Classifier, x: tf.Tensor, target: tf.Te
175176
with tf.GradientTape() as t: # pylint: disable=C0103
176177
t.watch(classifier.model.weights)
177178
output = classifier.model(x, training=False)
178-
loss = classifier.model.compiled_loss(target, output)
179+
loss = classifier.loss_object(target, output)
179180
d_w = t.gradient(loss, classifier.model.weights)
180181
d_w = [w for w in d_w if w is not None]
181182
d_w = tf.concat([tf.reshape(d, [-1]) for d in d_w], 0)
@@ -478,7 +479,11 @@ def __len__(self):
478479
PoisonDataset(x_poison, y_poison), batch_size=self.batch_size, shuffle=False, num_workers=1
479480
)
480481

481-
epoch_iterator = trange(self.max_epochs) if self.verbose > 0 else range(self.max_epochs)
482+
epoch_iterator = (
483+
trange(self.initial_epoch, self.max_epochs)
484+
if self.verbose > 0
485+
else range(self.initial_epoch, self.max_epochs)
486+
)
482487
for _ in epoch_iterator:
483488
batch_iterator = tqdm(trainloader) if isinstance(self.verbose, int) and self.verbose >= 2 else trainloader
484489
sum_loss = 0
@@ -536,6 +541,7 @@ def _poison__tensorflow(self, x_poison: np.ndarray, y_poison: np.ndarray) -> Tup
536541
[x_poison, y_poison, np.arange(len(y_poison))],
537542
callbacks=callbacks,
538543
batch_size=self.batch_size,
544+
initial_epoch=self.initial_epoch,
539545
epochs=self.max_epochs,
540546
verbose=0,
541547
)

0 commit comments

Comments
 (0)