Skip to content

Commit 3e0382c

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent a131e39 commit 3e0382c

File tree

4 files changed

+131
-190
lines changed

4 files changed

+131
-190
lines changed

.github/actions/yolo/run.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ if [[ $? -ne 0 ]]; then exit_code=1; echo "Failed estimators/object_detection/te
88
pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_object_seeker_yolo.py --framework=pytorch --durations=0
99
if [[ $? -ne 0 ]]; then exit_code=1; echo "Failed estimators/object_detection/test_object_seeker_yolo tests"; fi
1010

11+
pytest --cov-report=xml --cov=art --cov-append -q -vv tests/attacks/test_overload_attack.py --framework=pytorch --durations=0
12+
if [[ $? -ne 0 ]]; then exit_code=1; echo "Failed attacks/test_overload_attack tests"; fi
13+
1114
exit ${exit_code}

examples/get_started_yolo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def plot_image_with_boxes(img, boxes, pred_cls, title):
147147
text_th = 3
148148
rect_th = 1
149149

150+
img = img.copy()
151+
150152
for i in range(len(boxes)):
151153
cv2.rectangle(
152154
img,
@@ -206,8 +208,10 @@ def forward(self, x, targets=None):
206208
else:
207209
return self.model(x)
208210

209-
model_path = "./yolov3.cfg"
210-
weights_path = "./yolov3.weights"
211+
# model_path = "./yolov3.cfg"
212+
# weights_path = "./yolov3.weights"
213+
model_path = "/tmp/PyTorch-YOLOv3/config/yolov3.cfg"
214+
weights_path = "/tmp/PyTorch-YOLOv3/weights/yolov3.weights"
211215
model = load_model(model_path=model_path, weights_path=weights_path)
212216

213217
model = Yolo(model)

tests/attacks/evasion/test_overload_attack.py

Lines changed: 0 additions & 188 deletions
This file was deleted.

tests/attacks/test_overload_attack.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
# import logging
19+
#
20+
# import pytest
21+
22+
from art.attacks.evasion.overload.overload import OverloadPyTorch
23+
24+
from tests.utils import ARTTestException
25+
from tests.estimators.object_detection.conftest import *
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
@pytest.mark.only_with_platform("pytorch")
31+
def test_generate(art_warning, get_pytorch_yolo):
32+
try:
33+
from io import BytesIO
34+
from PIL import Image
35+
import requests
36+
import torch
37+
from ultralytics import YOLO
38+
39+
threshold = 0.85
40+
41+
object_detector, _, _ = get_pytorch_yolo
42+
object_detector.set_params(input_shape=(3, 640, 640))
43+
44+
# Download a sample image
45+
target = "https://ultralytics.com/images/zidane.jpg"
46+
response = requests.get(target)
47+
org_img = np.asarray(Image.open(BytesIO(response.content)).resize((640, 640)))
48+
x_255 = np.stack([org_img.transpose((2, 0, 1))], axis=0).astype(np.uint8)
49+
x = x_255.astype(np.float32) / 255.0
50+
51+
y_pred = object_detector.predict(x=x)
52+
53+
attack = OverloadPyTorch(
54+
object_detector, eps=16.0 / 255.0, max_iter=10, num_grid=10, batch_size=1, threshold=threshold
55+
)
56+
57+
x_adv = attack.generate(x=x, y=y_pred)
58+
59+
assert x.shape == x_adv.shape
60+
assert np.min(x_adv) >= 0.0
61+
assert np.max(x_adv) <= 1.0
62+
63+
y_pred_adv = object_detector.predict(x=x_adv)
64+
65+
scores_list = list(y_pred[0]["scores"])
66+
scores_list_adv = list(y_pred_adv[0]["scores"])
67+
68+
scores_list_filtered = [scores_list.index(x) for x in scores_list if x > threshold]
69+
scores_list_adv_filtered = [scores_list_adv.index(x) for x in scores_list_adv if x > threshold]
70+
71+
assert len(scores_list_filtered) == 11
72+
assert len(scores_list_adv_filtered) == 178
73+
74+
except ARTTestException as e:
75+
art_warning(e)
76+
77+
78+
@pytest.mark.only_with_platform("pytorch")
79+
def test_check_params(art_warning, get_pytorch_yolo):
80+
try:
81+
object_detector, _, _ = get_pytorch_yolo
82+
object_detector.set_params(input_shape=(3, 640, 640))
83+
84+
with pytest.raises(ValueError):
85+
_ = OverloadPyTorch(
86+
estimator=object_detector, eps=-1.0, max_iter=5, num_grid=10, batch_size=1, threshold=0.5
87+
)
88+
with pytest.raises(ValueError):
89+
_ = OverloadPyTorch(
90+
estimator=object_detector, eps=2.0, max_iter=5, num_grid=10, batch_size=1, threshold=0.5
91+
)
92+
with pytest.raises(TypeError):
93+
_ = OverloadPyTorch(
94+
estimator=object_detector, eps=8 / 255.0, max_iter=1.0, num_grid=10, batch_size=1, threshold=0.5
95+
)
96+
with pytest.raises(ValueError):
97+
_ = OverloadPyTorch(
98+
estimator=object_detector, eps=8 / 255.0, max_iter=0, num_grid=10, batch_size=1, threshold=0.5
99+
)
100+
with pytest.raises(TypeError):
101+
_ = OverloadPyTorch(
102+
estimator=object_detector, eps=8 / 255.0, max_iter=5, num_grid=1.0, batch_size=1, threshold=0.5
103+
)
104+
with pytest.raises(ValueError):
105+
_ = OverloadPyTorch(
106+
estimator=object_detector, eps=8 / 255.0, max_iter=5, num_grid=0, batch_size=1, threshold=0.5
107+
)
108+
with pytest.raises(TypeError):
109+
_ = OverloadPyTorch(
110+
estimator=object_detector, eps=8 / 255.0, max_iter=5, num_grid=10, batch_size=1.0, threshold=0.5
111+
)
112+
with pytest.raises(ValueError):
113+
_ = OverloadPyTorch(
114+
estimator=object_detector, eps=8 / 255.0, max_iter=5, num_grid=0, batch_size=0, threshold=0.5
115+
)
116+
with pytest.raises(ValueError):
117+
_ = OverloadPyTorch(
118+
estimator=object_detector, eps=8 / 255.0, max_iter=5, num_grid=0, batch_size=1, threshold=1.5
119+
)
120+
121+
except ARTTestException as e:
122+
art_warning(e)

0 commit comments

Comments
 (0)