@@ -73,6 +73,8 @@ def __init__(
7373
7474 :param classifier: The proxy classifier used for the attack.
7575 :param percent_poison: The ratio of samples to poison among x_train, with range [0,1].
76+ :patch: The patch to be applied as trigger.
77+ :indices_target: The indices of training data having target label.
7678 :param epsilon: The L-inf perturbation budget.
7779 :param max_trials: The maximum number of restarts to optimize the poison.
7880 :param max_epochs: The maximum number of epochs to optimize the train per trial.
@@ -82,15 +84,15 @@ def __init__(
8284 :param batch_size: Batch size.
8385 :param clip_values: The range of the input features to the classifier.
8486 :param verbose: Show progress bars.
85- :indices_target: The indices of training data having target label.
8687 :patching_strategy: Patching strategy to be used for adding trigger, either random/fixed.
8788 :selection_strategy: Selection strategy for getting the indices of
8889 poison examples - either random/maximum gradient norm.
8990 :retraining_factor: The factor for which retraining needs to be applied.
9091 :model_retrain: True, if retraining has to be applied, else False.
9192 :model_retraining_epoch: The epochs for which retraining has to be applied.
92- :patch: The patch to be applied as trigger.
93- :K: Number of training samples belonging to target class selected for poisoning.
93+ :class_source: Source class from which triggers are selected.
94+ :class_target: The target class to be misclassified after poisoning.
95+
9496 """
9597 super ().__init__ (
9698 classifier ,
@@ -152,7 +154,7 @@ def poison( # type: ignore
152154 raise NotImplementedError ("SleeperAgentAttack is currently implemented only for PyTorch." )
153155
154156 # Choose samples to poison.
155- x_trigger = self .apply_trigger_patch (x_trigger )
157+ x_trigger = self .apply_trigger_patch (x_trigger )
156158 if len (np .shape (y_trigger )) == 2 : # dense labels
157159 classes_target = set (np .argmax (y_trigger , axis = - 1 ))
158160 else : # sparse labels
@@ -410,7 +412,6 @@ def apply_trigger_patch(self, x_trigger: np.ndarray) -> Union[np.ndarray, "torch
410412 :x_trigger: Samples to be used for trigger.
411413 :return tensor with applied trigger patches.
412414 """
413- from art .estimators .classification .pytorch import PyTorchClassifier
414415
415416 patch_size = self .patch .shape [1 ]
416417 if self .patching_strategy == "fixed" :
@@ -420,9 +421,8 @@ def apply_trigger_patch(self, x_trigger: np.ndarray) -> Union[np.ndarray, "torch
420421 x_cord = random .randrange (0 , x .shape [1 ] - self .patch .shape [1 ] + 1 )
421422 y_cord = random .randrange (0 , x .shape [2 ] - self .patch .shape [2 ] + 1 )
422423 x [x_cord : x_cord + patch_size , y_cord : y_cord + patch_size , :] = self .patch
423- if isinstance (self .substitute_classifier , PyTorchClassifier ):
424- import torch
425424
426- return torch .tensor (np .transpose (x_trigger , [0 , 3 , 1 , 2 ]))
425+ if self .estimator .channels_first :
426+ return np .transpose (x_trigger , [0 , 3 , 1 , 2 ])
427427
428- return np . transpose ( x_trigger , [ 0 , 3 , 1 , 2 ])
428+ return x_trigger
0 commit comments