2323from __future__ import absolute_import , division , print_function , unicode_literals
2424
2525import logging
26- from typing import Any , Tuple , TYPE_CHECKING , Union , List
26+ from typing import Any , Tuple , TYPE_CHECKING , List
2727import random
2828
2929import numpy as np
@@ -73,7 +73,6 @@ def __init__(
7373
7474 :param classifier: The proxy classifier used for the attack.
7575 :param percent_poison: The ratio of samples to poison among x_train, with range [0,1].
76- :patch: The patch to be applied as trigger.
7776 :indices_target: The indices of training data having target label.
7877 :param epsilon: The L-inf perturbation budget.
7978 :param max_trials: The maximum number of restarts to optimize the poison.
@@ -90,9 +89,8 @@ def __init__(
9089 :retraining_factor: The factor for which retraining needs to be applied.
9190 :model_retrain: True, if retraining has to be applied, else False.
9291 :model_retraining_epoch: The epochs for which retraining has to be applied.
93- :class_source: Source class from which triggers are selected.
94- :class_target: The target class to be misclassified after poisoning.
95-
92+ :patch: The patch to be applied as trigger.
93+ :K: Number of training samples belonging to target class selected for poisoning.
9694 """
9795 super ().__init__ (
9896 classifier ,
@@ -111,7 +109,7 @@ def __init__(
111109 self .retraining_factor = retraining_factor
112110 self .model_retrain = model_retrain
113111 self .model_retraining_epoch = model_retraining_epoch
114- self .indices_poison = List [ int ]
112+ self .indices_poison : np . ndarray
115113 self .patch = patch
116114 self .class_target = class_target
117115 self .class_source = class_source
@@ -147,9 +145,8 @@ def poison( # type: ignore
147145 poisoner = self ._poison__pytorch
148146 finish_poisoning = self ._finish_poison_pytorch
149147 initializer = self ._initialize_poison_pytorch
150- x_train_target_samples = torch .tensor (
151- np .transpose (x_train_target_samples , [0 , 3 , 1 , 2 ]), dtype = torch .float32
152- ) # type: ignore
148+ if self .estimator .channels_first :
149+ x_train_target_samples = np .transpose (x_train_target_samples , [0 , 3 , 1 , 2 ])
153150 else :
154151 raise NotImplementedError ("SleeperAgentAttack is currently implemented only for PyTorch." )
155152
@@ -163,8 +160,8 @@ def poison( # type: ignore
163160
164161 # Try poisoning num_trials times and choose the best one.
165162 best_B = np .finfo (np .float32 ).max # pylint: disable=C0103
166- best_x_poisoned = None
167- best_indices_poison = None
163+ best_x_poisoned : np . ndarray
164+ best_indices_poison : np . ndarray
168165
169166 if len (np .shape (y_train )) == 2 :
170167 y_train_classes = np .argmax (y_train_target_samples , axis = - 1 )
@@ -178,7 +175,7 @@ def poison( # type: ignore
178175 else :
179176 self .indices_poison = self .select_poison_indices (
180177 self .substitute_classifier , x_train_target_samples , y_train_target_samples , num_poison_samples
181- ) # type: ignore
178+ )
182179 x_poison = x_train_target_samples [self .indices_poison ]
183180 y_poison = y_train_target_samples [self .indices_poison ]
184181 initializer (x_trigger , y_trigger , x_poison , y_poison )
@@ -221,7 +218,7 @@ def select_target_train_samples(self, x_train: np.ndarray, y_train: np.ndarray)
221218 y_train_target_samples = y_train [index_target ]
222219 return x_train_target_samples , y_train_target_samples
223220
224- def get_poison_indices (self ) -> List [ int ] :
221+ def get_poison_indices (self ) -> np . ndarray :
225222 """
226223 :return: indices of best poison index
227224 """
@@ -248,7 +245,6 @@ def model_retraining(
248245
249246 x_train = np .transpose (x_train , [0 , 3 , 1 , 2 ])
250247
251- poisoned_samples = np .asarray (poisoned_samples )
252248 x_train [self .indices_target [self .indices_poison ]] = poisoned_samples
253249 model , loss_fn , optimizer = self .create_model (
254250 x_train ,
@@ -288,7 +284,7 @@ def create_model(
288284 :param epochs: The number of epochs for which training need to be applied.
289285 :return model, loss_fn, optimizer - trained model, loss function used to train the model and optimizer used.
290286 """
291- import torch # lgtm [py/repeated-import]
287+ import torch
292288 from torch import nn
293289 from torch .utils .data import TensorDataset , DataLoader
294290 import torchvision
@@ -370,7 +366,7 @@ def test_accuracy(cls, model: "torch.nn.Module", test_loader: "torch.utils.data.
370366 @classmethod
371367 def select_poison_indices (
372368 cls , classifier : "CLASSIFIER_NEURALNETWORK_TYPE" , x_samples : np .ndarray , y_samples : np .ndarray , num_poison : int
373- ) -> List [ int ] :
369+ ) -> np . ndarray :
374370 """
375371 Select indices of poisoned samples
376372
@@ -380,7 +376,7 @@ def select_poison_indices(
380376 :num_poison: Number of poisoned samples to be selected out of all x_samples.
381377 :return indices - Indices of samples to be poisoned.
382378 """
383- import torch # lgtm [py/repeated-import]
379+ import torch
384380
385381 device = "cuda" if torch .cuda .is_available () else "cpu"
386382 grad_norms = []
@@ -393,26 +389,25 @@ def select_poison_indices(
393389 label = torch .tensor (y ).to (device )
394390 loss = criterion (model (image .unsqueeze (0 )), label .unsqueeze (0 ))
395391 gradients = torch .autograd .grad (loss , differentiable_params , only_inputs = True )
396- grad_norm = 0
392+ grad_norm = torch . tensor ( 0 )
397393 for grad in gradients :
398394 grad_norm += grad .detach ().pow (2 ).sum ()
399395 grad_norms .append (grad_norm .sqrt ())
400396
401397 indices = sorted (range (len (grad_norms )), key = lambda k : grad_norms [k ])
402398 indices = indices [- num_poison :]
403- return indices # this will get only indices for target class
399+ return np . array ( indices ) # this will get only indices for target class
404400
405401 # Applies the trigger patch to the images according to the patching strategy:
406402 # fixed - the trigger is placed at the bottom-right corner of the image
407403 # random - the trigger is placed at a random location in the image
408- def apply_trigger_patch (self , x_trigger : np .ndarray ) -> Union [ np .ndarray , "torch.Tensor" ] :
404+ def apply_trigger_patch (self , x_trigger : np .ndarray ) -> np .ndarray :
409405 """
410406 Apply the trigger patch to the given samples.
411407
412408 :x_trigger: Samples to be used for trigger.
413409 :return: Array with the trigger patches applied.
414410 """
415-
416411 patch_size = self .patch .shape [1 ]
417412 if self .patching_strategy == "fixed" :
418413 x_trigger [:, - patch_size :, - patch_size :, :] = self .patch
0 commit comments