Skip to content

Commit c7ddc77

Browse files
committed
Adding changes related to style and types
Signed-off-by: monshri <[email protected]> Signed-off-by: Shriti Priya <[email protected]>
1 parent f94b080 commit c7ddc77

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

art/attacks/poisoning/sleeper_agent_attack.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def __init__(
7373
7474
:param classifier: The proxy classifier used for the attack.
7575
:param percent_poison: The ratio of samples to poison among x_train, with range [0,1].
76+
:patch: The patch to be applied as trigger.
77+
:indices_target: The indices of training data having target label.
7678
:param epsilon: The L-inf perturbation budget.
7779
:param max_trials: The maximum number of restarts to optimize the poison.
7880
:param max_epochs: The maximum number of epochs to optimize the train per trial.
@@ -82,15 +84,15 @@ def __init__(
8284
:param batch_size: Batch size.
8385
:param clip_values: The range of the input features to the classifier.
8486
:param verbose: Show progress bars.
85-
:indices_target: The indices of training data having target label.
8687
:patching_strategy: Patching strategy to be used for adding trigger, either random/fixed.
8788
:selection_strategy: Selection strategy for getting the indices of
8889
poison examples - either random/maximum gradient norm.
8990
:retraining_factor: The factor for which retraining needs to be applied.
9091
:model_retrain: True, if retraining has to be applied, else False.
9192
:model_retraining_epoch: The epochs for which retraining has to be applied.
92-
:patch: The patch to be applied as trigger.
93-
:K: Number of training samples belonging to target class selected for poisoning.
93+
:class_source: Source class from which triggers are selected.
94+
:class_target: The target class to be misclassified after poisoning.
95+
9496
"""
9597
super().__init__(
9698
classifier,
@@ -152,7 +154,7 @@ def poison( # type: ignore
152154
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")
153155

154156
# Choose samples to poison.
155-
x_trigger = self.apply_trigger_patch(x_trigger)
157+
x_trigger = self.apply_trigger_patch(x_trigger)
156158
if len(np.shape(y_trigger)) == 2: # dense labels
157159
classes_target = set(np.argmax(y_trigger, axis=-1))
158160
else: # sparse labels
@@ -410,7 +412,6 @@ def apply_trigger_patch(self, x_trigger: np.ndarray) -> Union[np.ndarray, "torch
410412
:x_trigger: Samples to be used for trigger.
411413
:return tensor with applied trigger patches.
412414
"""
413-
from art.estimators.classification.pytorch import PyTorchClassifier
414415

415416
patch_size = self.patch.shape[1]
416417
if self.patching_strategy == "fixed":
@@ -420,9 +421,8 @@ def apply_trigger_patch(self, x_trigger: np.ndarray) -> Union[np.ndarray, "torch
420421
x_cord = random.randrange(0, x.shape[1] - self.patch.shape[1] + 1)
421422
y_cord = random.randrange(0, x.shape[2] - self.patch.shape[2] + 1)
422423
x[x_cord : x_cord + patch_size, y_cord : y_cord + patch_size, :] = self.patch
423-
if isinstance(self.substitute_classifier, PyTorchClassifier):
424-
import torch
425424

426-
return torch.tensor(np.transpose(x_trigger, [0, 3, 1, 2]))
425+
if self.estimator.channels_first:
426+
return np.transpose(x_trigger, [0, 3, 1, 2])
427427

428-
return np.transpose(x_trigger, [0, 3, 1, 2])
428+
return x_trigger

0 commit comments

Comments
 (0)