Skip to content

Commit d27f500

Browse files
committed
Add a flag to be used for marking the YOLOv8 model
1 parent 693e545 commit d27f500

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
"loss_rpn_box_reg",
6666
),
6767
device_type: str = "gpu",
68+
is_yolov8: bool = False,
6869
):
6970
"""
7071
Initialization.
@@ -92,6 +93,7 @@ def __init__(
9293
'loss_objectness', and 'loss_rpn_box_reg'.
9394
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
9495
if available otherwise run on CPU.
96+
:param is_yolov8: The flag to be used for marking the YOLOv8 model.
9597
"""
9698
import re
9799
import torch
@@ -140,9 +142,10 @@ def __init__(
140142

141143
self._model: torch.nn.Module
142144
self._model.to(self._device)
143-
try:
145+
self.is_yolov8 = is_yolov8
146+
if self.is_yolov8:
144147
self._model.model.eval()
145-
except AttributeError:
148+
else:
146149
self._model.eval()
147150

148151
@property
@@ -406,9 +409,9 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> list[dict[s
406409
from torch.utils.data import TensorDataset, DataLoader
407410

408411
# Set model to evaluation mode
409-
try:
412+
if self.is_yolov8:
410413
self._model.model.eval()
411-
except AttributeError:
414+
else:
412415
self._model.eval()
413416

414417
# Apply preprocessing and convert to tensors

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
"loss_rpn_box_reg",
6565
),
6666
device_type: str = "gpu",
67+
is_yolov8: bool = False,
6768
):
6869
"""
6970
Initialization.
@@ -92,6 +93,7 @@ def __init__(
9293
'loss_objectness', and 'loss_rpn_box_reg'.
9394
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
9495
if available otherwise run on CPU.
96+
:param is_yolov8: The flag to be used for marking the YOLOv8 model.
9597
"""
9698
super().__init__(
9799
model=model,
@@ -104,6 +106,7 @@ def __init__(
104106
preprocessing=preprocessing,
105107
attack_losses=attack_losses,
106108
device_type=device_type,
109+
is_yolov8=is_yolov8,
107110
)
108111

109112
def _translate_labels(self, labels: list[dict[str, "torch.Tensor"]]) -> "torch.Tensor":

notebooks/snal.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
"model = YOLO('yolov8m')\n",
7272
"py_model = PyTorchYolo(model=model,\n",
7373
" input_shape=(3, 640, 640),\n",
74-
" channels_first=True)\n",
74+
" channels_first=True,\n",
75+
" is_yolov8=True)\n",
7576
"\n",
7677
"# Define a custom function to collect patches from images\n",
7778
"def collect_patches_from_images(model: \"torch.nn.Module\",\n",

tests/attacks/evasion/test_steal_now_attack_later.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_generate(art_warning):
3535
import requests
3636

3737
model = YOLO("yolov8m")
38-
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
38+
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
3939

4040
# Define a custom function to collect patches from images
4141
def collect_patches_from_images(model, imgs):
@@ -192,7 +192,7 @@ def _loader(self, path):
192192
def test_check_params(art_warning):
193193
try:
194194
model = YOLO("yolov8m")
195-
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
195+
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
196196

197197
def dummy_func(model, imags):
198198
candidates_patch = []

0 commit comments

Comments
 (0)