Skip to content

Commit 6c7befd

Browse files
committed
Update object detection examples
Signed-off-by: Beat Buesser <[email protected]>
1 parent 7afb3b2 commit 6c7befd

File tree

2 files changed

+56
-50
lines changed

2 files changed

+56
-50
lines changed

examples/application_object_detection.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import matplotlib.pyplot as plt
44

55
from art.estimators.object_detection import PyTorchFasterRCNN
6-
from art.attacks.evasion import ProjectedGradientDescent
6+
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo
7+
from art.attacks.evasion import ProjectedGradientDescent, AdversarialPatchPyTorch
78

89
COCO_INSTANCE_CATEGORY_NAMES = [
910
"__background__",
@@ -118,15 +119,18 @@ def extract_predictions(predictions_):
118119

119120
predictions_boxes = predictions_boxes[: predictions_t + 1]
120121
predictions_class = predictions_class[: predictions_t + 1]
122+
predictions_scores = predictions_score[: predictions_t + 1]
121123

122-
return predictions_class, predictions_boxes, predictions_class
124+
return predictions_class, predictions_boxes, predictions_scores
123125

124126

125127
def plot_image_with_boxes(img, boxes, pred_cls):
126128
text_size = 2
127129
text_th = 2
128130
rect_th = 2
129131

132+
img = img.copy()
133+
130134
for i in range(len(boxes)):
131135
# Draw Rectangle with the coordinates
132136

@@ -186,14 +190,14 @@ def main():
186190
print("\nPredictions image {}:".format(i))
187191

188192
# Process predictions
189-
predictions_class, predictions_boxes, predictions_class = extract_predictions(predictions[i])
193+
predictions_class, predictions_boxes, _ = extract_predictions(predictions[i])
190194

191195
# Plot predictions
192-
plot_image_with_boxes(img=image[i].copy(), boxes=predictions_boxes, pred_cls=predictions_class)
196+
plot_image_with_boxes(img=image[i], boxes=predictions_boxes, pred_cls=predictions_class)
193197

194198
# Create and run attack
195199
eps = 32
196-
attack = ProjectedGradientDescent(estimator=frcnn, eps=eps, eps_step=2, max_iter=10)
200+
attack = ProjectedGradientDescent(estimator=frcnn, eps=eps, eps_step=2, max_iter=2)
197201
image_adv_chw = attack.generate(x=image_chw, y=None)
198202
image_adv = np.transpose(image_adv_chw, (0, 2, 3, 1))
199203

@@ -212,10 +216,10 @@ def main():
212216
print("\nPredictions adversarial image {}:".format(i))
213217

214218
# Process predictions
215-
predictions_adv_class, predictions_adv_boxes, predictions_adv_class = extract_predictions(predictions_adv[i])
219+
predictions_adv_class, predictions_adv_boxes, _ = extract_predictions(predictions_adv[i])
216220

217221
# Plot predictions
218-
plot_image_with_boxes(img=image_adv[i].copy(), boxes=predictions_adv_boxes, pred_cls=predictions_adv_class)
222+
plot_image_with_boxes(img=image_adv[i], boxes=predictions_adv_boxes, pred_cls=predictions_adv_class)
219223

220224

221225
if __name__ == "__main__":

examples/get_started_yolo.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -115,37 +115,45 @@
115115
]
116116

117117

118-
def extract_predictions(predictions_, conf_thresh):
118+
def extract_predictions(predictions_, top_k):
119119
# Get the predicted class
120120
predictions_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(predictions_["labels"])]
121-
# print("\npredicted classes:", predictions_class)
122-
if len(predictions_class) < 1:
123-
return [], [], []
121+
124122
# Get the predicted bounding boxes
125123
predictions_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(predictions_["boxes"])]
126124

127125
# Get the predicted prediction score
128126
predictions_score = list(predictions_["scores"])
129-
# print("predicted score:", predictions_score)
127+
128+
# sort all lists according to scores
129+
# Combine into a list of tuples
130+
combined = list(zip(predictions_score, predictions_boxes, predictions_class))
131+
132+
# Sort by score (first element of tuple), descending
133+
combined_sorted = sorted(combined, key=lambda x: x[0], reverse=True)
134+
135+
# Unpack sorted tuples
136+
predictions_score, predictions_boxes, predictions_class = zip(*combined_sorted)
137+
138+
# Convert back to lists
139+
predictions_score = list(predictions_score)
140+
predictions_boxes = list(predictions_boxes)
141+
predictions_class = list(predictions_class) # Combine into a list of tuples
130142

131143
# Get a list of index with score greater than threshold
132-
threshold = conf_thresh
133-
predictions_t = [predictions_score.index(x) for x in predictions_score if x > threshold]
134-
if len(predictions_t) == 0:
135-
return [], [], []
136-
137-
# predictions in score order
138-
predictions_boxes = [predictions_boxes[i] for i in predictions_t]
139-
predictions_class = [predictions_class[i] for i in predictions_t]
140-
predictions_scores = [predictions_score[i] for i in predictions_t]
144+
predictions_t = top_k
145+
146+
predictions_boxes = predictions_boxes[:predictions_t]
147+
predictions_class = predictions_class[:predictions_t]
148+
predictions_scores = predictions_score[:predictions_t]
149+
141150
return predictions_class, predictions_boxes, predictions_scores
142151

143152

144153
def plot_image_with_boxes(img, boxes, pred_cls, title):
145-
plt.style.use("ggplot")
146-
text_size = 1
147-
text_th = 3
148-
rect_th = 1
154+
text_size = 2
155+
text_th = 2
156+
rect_th = 2
149157

150158
img = img.copy()
151159

@@ -175,18 +183,10 @@ def plot_image_with_boxes(img, boxes, pred_cls, title):
175183
plt.show()
176184

177185

178-
"""
179-
################# Evasion settings #################
180-
"""
181-
eps = 32
182-
eps_step = 2
183-
max_iter = 10
184-
185-
186186
"""
187187
################# Model definition #################
188188
"""
189-
MODEL = "yolov3" # OR yolov5
189+
MODEL = "yolov5" # OR yolov5
190190

191191

192192
if MODEL == "yolov3":
@@ -265,35 +265,37 @@ def forward(self, x, targets=None):
265265
"""
266266
response = requests.get("https://ultralytics.com/images/zidane.jpg")
267267
img = np.asarray(Image.open(BytesIO(response.content)).resize((640, 640)))
268-
img_reshape = img.transpose((2, 0, 1))
269-
image = np.stack([img_reshape], axis=0).astype(np.float32)
270-
x = image.copy()
268+
image = np.stack([img], axis=0).astype(np.float32)
269+
image_chw = np.transpose(image, (0, 3, 1, 2))
271270

272271
"""
273272
################# Evasion attack #################
274273
"""
275274

276-
attack = ProjectedGradientDescent(estimator=detector, eps=eps, eps_step=eps_step, max_iter=max_iter)
277-
image_adv = attack.generate(x=x, y=None)
275+
eps = 32
276+
attack = ProjectedGradientDescent(estimator=detector, eps=eps, eps_step=2, max_iter=10)
277+
image_adv_chw = attack.generate(x=image_chw, y=None)
278+
image_adv = np.transpose(image_adv_chw, (0, 2, 3, 1))
278279

279280
print("\nThe attack budget eps is {}".format(eps))
280-
print("The resulting maximal difference in pixel values is {}.".format(np.amax(np.abs(x - image_adv))))
281+
print("The resulting maximal difference in pixel values is {}.".format(np.amax(np.abs(image_chw - image_adv_chw))))
281282

282283
plt.axis("off")
283284
plt.title("adversarial image")
284-
plt.imshow(image_adv[0].transpose(1, 2, 0).astype(np.uint8), interpolation="nearest")
285+
plt.imshow(image_adv[0].astype(np.uint8), interpolation="nearest")
285286
plt.show()
286287

287-
threshold = 0.85 # 0.5
288-
dets = detector.predict(x)
289-
preds = extract_predictions(dets[0], threshold)
290-
plot_image_with_boxes(img=img, boxes=preds[1], pred_cls=preds[0], title="Predictions on original image")
288+
predictions = detector.predict(x=image_chw)
289+
predictions_class, predictions_boxes, _ = extract_predictions(predictions[0], top_k=3)
290+
plot_image_with_boxes(
291+
img=image[0], boxes=predictions_boxes, pred_cls=predictions_class, title="Predictions on original image"
292+
)
291293

292-
dets = detector.predict(image_adv)
293-
preds = extract_predictions(dets[0], threshold)
294+
predictions = detector.predict(image_adv_chw)
295+
predictions_class, predictions_boxes, d = extract_predictions(predictions[0], top_k=3)
294296
plot_image_with_boxes(
295-
img=image_adv[0].transpose(1, 2, 0).copy(),
296-
boxes=preds[1],
297-
pred_cls=preds[0],
297+
img=image_adv[0],
298+
boxes=predictions_boxes,
299+
pred_cls=predictions_class,
298300
title="Predictions on adversarial image",
299301
)

0 commit comments

Comments
 (0)