Commit c295862

Adding formatting changes

Signed-off-by: Shriti Priya <[email protected]>

1 parent: c7ddc77

File tree: 1 file changed (+16, -21 lines)


art/attacks/poisoning/sleeper_agent_attack.py

Lines changed: 16 additions & 21 deletions
@@ -23,7 +23,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import logging
-from typing import Any, Tuple, TYPE_CHECKING, Union, List
+from typing import Any, Tuple, TYPE_CHECKING, List
 import random

 import numpy as np
@@ -73,7 +73,6 @@ def __init__(

         :param classifier: The proxy classifier used for the attack.
         :param percent_poison: The ratio of samples to poison among x_train, with range [0,1].
-        :patch: The patch to be applied as trigger.
         :indices_target: The indices of training data having target label.
         :param epsilon: The L-inf perturbation budget.
         :param max_trials: The maximum number of restarts to optimize the poison.
@@ -90,9 +89,8 @@ def __init__(
         :retraining_factor: The factor for which retraining needs to be applied.
         :model_retrain: True, if retraining has to be applied, else False.
         :model_retraining_epoch: The epochs for which retraining has to be applied.
-        :class_source: Source class from which triggers are selected.
-        :class_target: The target class to be misclassified after poisoning.
-
+        :patch: The patch to be applied as trigger.
+        :K: Number of training samples belonging to target class selected for poisoning.
         """
         super().__init__(
             classifier,
@@ -111,7 +109,7 @@ def __init__(
         self.retraining_factor = retraining_factor
         self.model_retrain = model_retrain
         self.model_retraining_epoch = model_retraining_epoch
-        self.indices_poison = List[int]
+        self.indices_poison: np.ndarray
         self.patch = patch
         self.class_target = class_target
         self.class_source = class_source
@@ -147,9 +145,8 @@ def poison( # type: ignore
             poisoner = self._poison__pytorch
             finish_poisoning = self._finish_poison_pytorch
             initializer = self._initialize_poison_pytorch
-            x_train_target_samples = torch.tensor(
-                np.transpose(x_train_target_samples, [0, 3, 1, 2]), dtype=torch.float32
-            )  # type: ignore
+            if self.estimator.channels_first:
+                x_train_target_samples = np.transpose(x_train_target_samples, [0, 3, 1, 2])
         else:
             raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")

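Note on the hunk above: the old code eagerly built a torch.Tensor, while the new code stays in NumPy and only reorders axes when the estimator expects channels-first input. A minimal sketch of that NHWC-to-NCHW transpose, with shapes invented for illustration:

    import numpy as np

    # Hypothetical batch of four 32x32 RGB images in channels-last (NHWC) layout.
    x_nhwc = np.zeros((4, 32, 32, 3), dtype=np.float32)

    # np.transpose with axes [0, 3, 1, 2] moves channels forward: NHWC -> NCHW.
    x_nchw = np.transpose(x_nhwc, [0, 3, 1, 2])
    assert x_nchw.shape == (4, 3, 32, 32)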
@@ -163,8 +160,8 @@ def poison( # type: ignore

         # Try poisoning num_trials times and choose the best one.
         best_B = np.finfo(np.float32).max  # pylint: disable=C0103
-        best_x_poisoned = None
-        best_indices_poison = None
+        best_x_poisoned: np.ndarray
+        best_indices_poison: np.ndarray

         if len(np.shape(y_train)) == 2:
             y_train_classes = np.argmax(y_train_target_samples, axis=-1)
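Side note on the best_x_poisoned lines above: a bare annotation declares a name's type for static checkers without binding a value, so the variable is no longer inferred as Optional from an "= None" initializer. A tiny generic illustration, not code from this file:

    import numpy as np

    best_x: np.ndarray  # declared only; reading it before assignment raises NameError

    for _ in range(3):
        candidate = np.random.rand(2, 2)
        best_x = candidate  # first assignment binds the name; checkers see np.ndarray

    print(best_x.shape)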
@@ -178,7 +175,7 @@ def poison( # type: ignore
             else:
                 self.indices_poison = self.select_poison_indices(
                     self.substitute_classifier, x_train_target_samples, y_train_target_samples, num_poison_samples
-                )  # type: ignore
+                )
             x_poison = x_train_target_samples[self.indices_poison]
             y_poison = y_train_target_samples[self.indices_poison]
             initializer(x_trigger, y_trigger, x_poison, y_poison)
@@ -221,7 +218,7 @@ def select_target_train_samples(self, x_train: np.ndarray, y_train: np.ndarray)
         y_train_target_samples = y_train[index_target]
         return x_train_target_samples, y_train_target_samples

-    def get_poison_indices(self) -> List[int]:
+    def get_poison_indices(self) -> np.ndarray:
         """
         :return: indices of best poison index
         """
@@ -248,7 +245,6 @@ def model_retraining(

         x_train = np.transpose(x_train, [0, 3, 1, 2])

-        poisoned_samples = np.asarray(poisoned_samples)
         x_train[self.indices_target[self.indices_poison]] = poisoned_samples
         model, loss_fn, optimizer = self.create_model(
             x_train,
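The assignment this hunk keeps writes poisoned samples back into x_train through two stacked index arrays. A toy sketch of that nested fancy-indexing pattern, with all shapes and values invented:

    import numpy as np

    x_train = np.zeros((10, 2))              # stand-in training set
    indices_target = np.array([1, 4, 7, 9])  # rows belonging to the target class
    indices_poison = np.array([0, 2])        # positions chosen within that subset

    # indices_target[indices_poison] selects rows 1 and 7, which are overwritten.
    x_train[indices_target[indices_poison]] = np.ones((2, 2))
    print(np.flatnonzero(x_train[:, 0]))     # -> [1 7]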
@@ -288,7 +284,7 @@ def create_model(
         :param epochs: The number of epochs for which training need to be applied.
         :return model, loss_fn, optimizer - trained model, loss function used to train the model and optimizer used.
         """
-        import torch  # lgtm [py/repeated-import]
+        import torch
         from torch import nn
         from torch.utils.data import TensorDataset, DataLoader
         import torchvision
@@ -370,7 +366,7 @@ def test_accuracy(cls, model: "torch.nn.Module", test_loader: "torch.utils.data.
     @classmethod
     def select_poison_indices(
         cls, classifier: "CLASSIFIER_NEURALNETWORK_TYPE", x_samples: np.ndarray, y_samples: np.ndarray, num_poison: int
-    ) -> List[int]:
+    ) -> np.ndarray:
         """
         Select indices of poisoned samples

@@ -380,7 +376,7 @@ def select_poison_indices(
         :num_poison: Number of poisoned samples to be selected out of all x_samples.
         :return indices - Indices of samples to be poisoned.
         """
-        import torch  # lgtm [py/repeated-import]
+        import torch

         device = "cuda" if torch.cuda.is_available() else "cpu"
         grad_norms = []
@@ -393,26 +389,25 @@ def select_poison_indices(
             label = torch.tensor(y).to(device)
             loss = criterion(model(image.unsqueeze(0)), label.unsqueeze(0))
             gradients = torch.autograd.grad(loss, differentiable_params, only_inputs=True)
-            grad_norm = 0
+            grad_norm = torch.tensor(0)
             for grad in gradients:
                 grad_norm += grad.detach().pow(2).sum()
             grad_norms.append(grad_norm.sqrt())

         indices = sorted(range(len(grad_norms)), key=lambda k: grad_norms[k])
         indices = indices[-num_poison:]
-        return indices  # this will get only indices for target class
+        return np.array(indices)  # this will get only indices for target class

     # This function is responsible for applying trigger patches to the images
     # fixed - where the trigger is applied at the bottom right of the image
     # random - where the trigger is applied at random location of the image
-    def apply_trigger_patch(self, x_trigger: np.ndarray) -> Union[np.ndarray, "torch.Tensor"]:
+    def apply_trigger_patch(self, x_trigger: np.ndarray) -> np.ndarray:
         """
         Select indices of poisoned samples

         :x_trigger: Samples to be used for trigger.
         :return tensor with applied trigger patches.
         """
-
         patch_size = self.patch.shape[1]
         if self.patching_strategy == "fixed":
             x_trigger[:, -patch_size:, -patch_size:, :] = self.patch
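For orientation on the hunk above: select_poison_indices ranks candidate samples by the L2 norm of their loss gradient and keeps the num_poison largest; the changes make grad_norm a tensor accumulator and the return value an np.ndarray. A self-contained sketch of that ranking idea, using a toy linear model as a stand-in for the substitute classifier (everything below is illustrative):

    import numpy as np
    import torch
    from torch import nn

    model = nn.Linear(8, 3)                  # toy stand-in for the substitute classifier
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]

    x_samples = torch.randn(20, 8)           # invented candidate samples
    y_samples = torch.randint(0, 3, (20,))
    num_poison = 5

    grad_norms = []
    for x, y in zip(x_samples, y_samples):
        loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        grad_norm = torch.tensor(0.0)        # tensor accumulator, as in the patched code
        for g in grads:
            grad_norm += g.detach().pow(2).sum()
        grad_norms.append(grad_norm.sqrt().item())

    # Sort ascending by norm and keep the largest num_poison, returned as an array.
    indices = np.array(sorted(range(len(grad_norms)), key=lambda k: grad_norms[k])[-num_poison:])
    print(indices)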

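Finally, the fixed patching strategy on the diff's last line stamps the trigger into the bottom-right corner of each image via negative slicing. A minimal NumPy sketch with invented sizes:

    import numpy as np

    x_trigger = np.zeros((2, 32, 32, 3))  # two NHWC images to be triggered
    patch = np.ones((8, 8, 3))            # hypothetical 8x8 trigger patch
    patch_size = patch.shape[1]

    # Negative slices address the last patch_size rows/columns: bottom-right corner.
    x_trigger[:, -patch_size:, -patch_size:, :] = patch
    assert x_trigger[0, -1, -1, 0] == 1.0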