Skip to content

Commit 7b10233

Browse files
author
Beat Buesser
committed
Add targeted version of DPatch
Signed-off-by: Beat Buesser <[email protected]>
1 parent fdd788a commit 7b10233

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

art/attacks/evasion/dpatch.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import math
2525
import random
26-
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
26+
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
2727

2828
import numpy as np
2929
from tqdm import trange
@@ -81,12 +81,21 @@ def __init__(
8181
self._patch = np.ones(shape=patch_shape) * (self.estimator.clip_values[1] + self.estimator.clip_values[0]) / 2.0
8282
self._check_params()
8383

84-
def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
84+
self.target_label = []
85+
86+
def generate(
87+
self,
88+
x: np.ndarray,
89+
y: Optional[np.ndarray] = None,
90+
target_label: Union[int, List[int], np.ndarray] = 0,
91+
**kwargs
92+
) -> np.ndarray:
8593
"""
8694
Generate DPatch.
8795
8896
:param x: Sample images.
8997
:param y: Target labels for object detector.
98+
:param target_label: The target label of the DPatch attack.
9099
:return: Adversarial patch.
91100
"""
92101
channel_index = 1 if self.estimator.channels_first else x.ndim - 1
@@ -96,6 +105,16 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
96105
raise ValueError("The DPatch attack does not use target labels.")
97106
if x.ndim != 4:
98107
raise ValueError("The adversarial patch can only be applied to images.")
108+
if isinstance(target_label, int):
109+
self.target_label = [target_label] * x.shape[0]
110+
elif isinstance(target_label, np.ndarray):
111+
if not (target_label.shape == (x.shape[0], 1) or target_label.shape == (x.shape[0],)):
112+
raise ValueError("The target_label has to be a 1-dimensional array.")
113+
self.target_label = target_label.tolist()
114+
else:
115+
if not len(target_label == x.shape[0]) or not isinstance(target_label, list):
116+
raise ValueError("The target_label as list of integers needs to of length number of images in `x`.")
117+
self.target_label = target_label
99118

100119
for i_step in trange(self.max_iter, desc="DPatch iteration"):
101120
if i_step == 0 or (i_step + 1) % 100 == 0:
@@ -115,7 +134,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
115134

116135
target_dict = dict()
117136
target_dict["boxes"] = np.asarray([[i_x_1, i_y_1, i_x_2, i_y_2]])
118-
target_dict["labels"] = np.asarray([1,])
137+
target_dict["labels"] = np.asarray([target_label[i_image],])
119138
target_dict["scores"] = np.asarray([1.0,])
120139

121140
patch_target.append(target_dict)

0 commit comments

Comments
 (0)