Skip to content

Commit 49c5490

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent ad5a98f commit 49c5490

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

tests/attacks/evasion/test_overload_attack.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
def test_generate(art_warning):
3232
try:
3333
import torch
34+
import ultralytics
3435
from ultralytics import YOLO
3536

3637
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
38+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
3739

3840
model = YOLO("yolov5su.pt")
3941
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
@@ -67,9 +69,11 @@ def test_generate(art_warning):
6769
def test_check_params(art_warning):
6870
try:
6971
import torch
72+
import ultralytics
7073
from ultralytics import YOLO
7174

7275
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
76+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
7377

7478
model = YOLO("yolov5su.pt")
7579
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)

tests/attacks/evasion/test_steal_now_attack_later.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def test_generate(art_warning):
3434
# To avoid an import error with the TF 1.x pipeline, it is imported only within the function scope.
3535
import torch
3636
import requests
37+
import ultralytics
3738
from ultralytics import YOLO
3839

3940
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
41+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
4042

4143
model = YOLO("yolov8m")
4244
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
@@ -199,9 +201,11 @@ def test_check_params(art_warning):
199201
# To avoid an import error with the TF 1.x pipeline, it is imported only within the function scope.
200202
import torch
201203
import requests
204+
import ultralytics
202205
from ultralytics import YOLO
203206

204207
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
208+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
205209

206210
model = YOLO("yolov8m")
207211
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)

0 commit comments

Comments
 (0)