"""
The script demonstrates a simple example of using ART with YOLO (versions 3 and 5).
The example loads a YOLO model pretrained on the COCO dataset and creates an
adversarial example using the Projected Gradient Descent method.

- To use YOLOv3, run:
    pip install pytorchyolo

- To use YOLOv5, run:
    pip install yolov5

Note: If pytorchyolo raises an error in pytorchyolo/utils/loss.py, insert the
following line before line 174 of that file:
    gain = gain.to(torch.int64)
"""

from io import BytesIO

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image

from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo


"""
################# Helper functions and labels #################
"""

COCO_INSTANCE_CATEGORY_NAMES = [
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]


def extract_predictions(predictions_, conf_thresh):
    """Filter detections by confidence and return their classes, boxes, and scores."""
    # Map label indices to COCO class names
    predictions_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(predictions_["labels"])]
    if len(predictions_class) < 1:
        return [], [], []

    # Bounding boxes as [(x1, y1), (x2, y2)] corner pairs
    predictions_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(predictions_["boxes"])]

    # Confidence score of each detection
    predictions_score = list(predictions_["scores"])

    # Indices of detections whose score exceeds the threshold (enumerate avoids
    # the duplicate-index bug of list.index() when two detections share a score)
    predictions_t = [i for i, score in enumerate(predictions_score) if score > conf_thresh]
    if len(predictions_t) == 0:
        # No predictions exceeding the threshold
        return [], [], []

    predictions_boxes = [predictions_boxes[i] for i in predictions_t]
    predictions_class = [predictions_class[i] for i in predictions_t]
    predictions_scores = [predictions_score[i] for i in predictions_t]
    return predictions_class, predictions_boxes, predictions_scores
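
# Example with hypothetical values: for a detection dict such as
#   {"boxes": np.array([[34.0, 50.0, 120.0, 200.0]]),
#    "labels": np.array([0]), "scores": np.array([0.91])}
# extract_predictions(det, 0.5) would return
#   (["person"], [[(34.0, 50.0), (120.0, 200.0)]], [0.91])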


def plot_image_with_boxes(img, boxes, pred_cls, title):
    """Draw boxes and class labels on img (modified in place by cv2) and display it."""
    plt.style.use("ggplot")
    text_size = 1
    text_th = 3
    rect_th = 1

    for i in range(len(boxes)):
        # Draw the bounding box rectangle
        cv2.rectangle(
            img,
            (int(boxes[i][0][0]), int(boxes[i][0][1])),
            (int(boxes[i][1][0]), int(boxes[i][1][1])),
            color=(0, 255, 0),
            thickness=rect_th,
        )
        # Write the predicted class at the top-left corner of the box
        cv2.putText(
            img,
            pred_cls[i],
            (int(boxes[i][0][0]), int(boxes[i][0][1])),
            cv2.FONT_HERSHEY_SIMPLEX,
            text_size,
            (0, 255, 0),
            thickness=text_th,
        )

    plt.figure()
    plt.axis("off")
    plt.title(title)
    plt.imshow(img.astype(np.uint8), interpolation="nearest")
    plt.show()


"""
################# Evasion settings #################
"""
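# The attack operates directly in pixel space (clip_values=(0, 255)), so with
# ART's default L-infinity norm, eps = 32 caps each pixel's change at 32 out of
# 255, applied in eps_step-sized increments over max_iter iterations.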
eps = 32
eps_step = 2
max_iter = 10


"""
################# Model definition #################
"""
MODEL = "yolov3"  # or "yolov5"


if MODEL == "yolov3":

    from pytorchyolo.models import load_model
    from pytorchyolo.utils.loss import compute_loss

    class Yolo(torch.nn.Module):
        """Wrap the Darknet model so training mode returns a loss dict, as PyTorchYolo expects."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x, targets=None):
            if self.training:
                outputs = self.model(x)
                # compute_loss returns the total loss and its per-term components
                loss, loss_components = compute_loss(outputs, targets, self.model)
                loss_components_dict = {"loss_total": loss}
                return loss_components_dict
            else:
                return self.model(x)
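
    # The Darknet config and weights files are assumed to be in the working
    # directory; "pip install pytorchyolo" does not ship them. The pytorchyolo
    # project provides scripts to download the official YOLOv3 weights.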
    model_path = "./yolov3.cfg"
    weights_path = "./yolov3.weights"
    model = load_model(model_path=model_path, weights_path=weights_path)

    model = Yolo(model)

    detector = PyTorchYolo(
        model=model, device_type="cpu", input_shape=(3, 640, 640), clip_values=(0, 255), attack_losses=("loss_total",)
    )

elif MODEL == "yolov5":

    import yolov5
    from yolov5.utils.loss import ComputeLoss

    matplotlib.use("TkAgg")

    class Yolo(torch.nn.Module):
        """Wrap the YOLOv5 model so training mode returns a loss dict, as PyTorchYolo expects."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            # Loss hyperparameters required by ComputeLoss (YOLOv5 scratch defaults)
            self.model.hyp = {
                "box": 0.05,
                "obj": 1.0,
                "cls": 0.5,
                "anchor_t": 4.0,
                "cls_pw": 1.0,
                "obj_pw": 1.0,
                "fl_gamma": 0.0,
            }
            self.compute_loss = ComputeLoss(self.model.model.model)

        def forward(self, x, targets=None):
            if self.training:
                outputs = self.model.model.model(x)
                loss, loss_items = self.compute_loss(outputs, targets)
                loss_components_dict = {"loss_total": loss}
                return loss_components_dict
            else:
                return self.model(x)

    # yolov5.load fetches the checkpoint automatically if it is not cached locally
    model = yolov5.load("yolov5s.pt")

    model = Yolo(model)

    detector = PyTorchYolo(
        model=model, device_type="cpu", input_shape=(3, 640, 640), clip_values=(0, 255), attack_losses=("loss_total",)
    )


"""
################# Example image #################
"""
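# The detector was configured with input_shape=(3, 640, 640) and
# clip_values=(0, 255), so the HWC uint8 image is resized, transposed to CHW,
# batched, and cast to float32 without rescaling.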
response = requests.get("https://ultralytics.com/images/zidane.jpg")
img = np.asarray(Image.open(BytesIO(response.content)).resize((640, 640)))
img_reshape = img.transpose((2, 0, 1))
image = np.stack([img_reshape], axis=0).astype(np.float32)
x = image.copy()

"""
################# Evasion attack #################
"""

attack = ProjectedGradientDescent(estimator=detector, eps=eps, eps_step=eps_step, max_iter=max_iter)
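# With y=None the attack is untargeted: ART derives labels from the detector's
# own predictions on x and perturbs the image away from them.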
image_adv = attack.generate(x=x, y=None)

print("\nThe attack budget eps is {}".format(eps))
print("The resulting maximal difference in pixel values is {}.".format(np.amax(np.abs(x - image_adv))))

plt.axis("off")
plt.title("adversarial image")
plt.imshow(image_adv[0].transpose(1, 2, 0).astype(np.uint8), interpolation="nearest")
plt.show()
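
# Optional: persist the adversarial example for later inspection; the file name
# here is arbitrary.
# Image.fromarray(image_adv[0].transpose(1, 2, 0).astype(np.uint8)).save("adv_example.png")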

threshold = 0.85  # confidence threshold for displayed detections
dets = detector.predict(x)
preds = extract_predictions(dets[0], threshold)
# Copy so in-place drawing does not fail on the read-only PIL-backed array
plot_image_with_boxes(img=img.copy(), boxes=preds[1], pred_cls=preds[0], title="Predictions on original image")

dets = detector.predict(image_adv)
preds = extract_predictions(dets[0], threshold)
plot_image_with_boxes(
    img=image_adv[0].transpose(1, 2, 0).copy(),
    boxes=preds[1],
    pred_cls=preds[0],
    title="Predictions on adversarial image",
)