Skip to content

Commit ad5a98f

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent f991702 commit ad5a98f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/attacks/evasion/test_overload_attack.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def test_generate(art_warning):
3232
try:
3333
import torch
3434
from ultralytics import YOLO
35-
from ultralytics.nn.tasks import DetectionModel
3635

3736
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
3837

@@ -68,13 +67,11 @@ def test_generate(art_warning):
6867
def test_check_params(art_warning):
6968
try:
7069
import torch
71-
from torch.serialization import add_safe_globals
7270
from ultralytics import YOLO
73-
from ultralytics.nn.tasks import DetectionModel
7471

75-
add_safe_globals([DetectionModel])
72+
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
7673

77-
model = YOLO("yolov5s.pt")
74+
model = YOLO("yolov5su.pt")
7875
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
7976

8077
with pytest.raises(ValueError):

tests/attacks/evasion/test_steal_now_attack_later.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_generate(art_warning):
3535
import torch
3636
import requests
3737
from ultralytics import YOLO
38-
from ultralytics.nn.tasks import DetectionModel
3938

4039
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
4140

@@ -198,8 +197,12 @@ def test_check_params(art_warning):
198197
try:
199198
# The ultralytics package does not support Python versions earlier than 3.8.
200199
# To avoid an import error with the TF 1.x pipeline, it is imported only within the function scope.
200+
import torch
201+
import requests
201202
from ultralytics import YOLO
202203

204+
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
205+
203206
model = YOLO("yolov8m")
204207
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
205208

0 commit comments

Comments
 (0)