Skip to content

Commit d0f7e90

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent 49c5490 commit d0f7e90

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

tests/attacks/evasion/test_overload_attack.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,20 @@ def test_generate(art_warning):
3535
from ultralytics import YOLO
3636

3737
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
38+
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
39+
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
40+
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
41+
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
42+
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
43+
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
3844
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
45+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
46+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
47+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
48+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
49+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
50+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
51+
torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
3952

4053
model = YOLO("yolov5su.pt")
4154
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
@@ -73,9 +86,23 @@ def test_check_params(art_warning):
7386
from ultralytics import YOLO
7487

7588
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
89+
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
90+
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
91+
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
92+
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
93+
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
94+
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
7695
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
96+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
97+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
98+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
99+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
100+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
101+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
102+
torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
77103

78104
model = YOLO("yolov5su.pt")
105+
print(model)
79106
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
80107

81108
with pytest.raises(ValueError):

tests/attacks/evasion/test_steal_now_attack_later.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,20 @@ def test_generate(art_warning):
3838
from ultralytics import YOLO
3939

4040
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
41+
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
42+
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
43+
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
44+
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
45+
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
46+
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
4147
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
48+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
49+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
50+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
51+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
52+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
53+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
54+
torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
4255

4356
model = YOLO("yolov8m")
4457
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
@@ -205,9 +218,24 @@ def test_check_params(art_warning):
205218
from ultralytics import YOLO
206219

207220
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
221+
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
222+
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
223+
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
224+
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
225+
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
226+
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
208227
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
228+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
229+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
230+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
231+
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
232+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
233+
torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
234+
torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
209235

210236
model = YOLO("yolov8m")
237+
print(model)
238+
sdf
211239
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)
212240

213241
def dummy_func(model, imags):

0 commit comments

Comments
 (0)