Skip to content

Commit a131e39

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent 36e0bb5 commit a131e39

File tree

1 file changed

+46
-10
lines changed

1 file changed

+46
-10
lines changed

tests/attacks/evasion/test_overload_attack.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,41 @@ def test_generate(art_warning):
3535
from ultralytics.nn.modules import Conv
3636
from ultralytics import YOLO
3737

38+
import importlib
39+
import inspect
40+
from torch.serialization import add_safe_globals
41+
3842
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
3943
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
4044
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
4145
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
4246
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
4347
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
4448
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
45-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
46-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
47-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
48-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
49-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
49+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
50+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
51+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
52+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
53+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
5054
# torch.serialization.add_safe_globals([Conv])
5155
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
5256
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
5357
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
5458

59+
yolo_modules = [
60+
"ultralytics.nn.modules",
61+
"ultralytics.nn.tasks",
62+
"ultralytics.nn.autobackend",
63+
]
64+
for module_name in yolo_modules:
65+
try:
66+
mod = importlib.import_module(module_name)
67+
for name, obj in inspect.getmembers(mod):
68+
if inspect.isclass(obj):
69+
add_safe_globals([obj])
70+
except ModuleNotFoundError:
71+
pass # Some modules may not exist in all YOLO versions
72+
5573
model = YOLO("yolov5su.pt")
5674

5775
# Collect all unique classes used in the model
@@ -96,23 +114,41 @@ def test_check_params(art_warning):
96114
from ultralytics import YOLO
97115
from ultralytics.nn.modules import Conv
98116

117+
import importlib
118+
import inspect
119+
from torch.serialization import add_safe_globals
120+
99121
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
100122
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
101123
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
102124
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
103125
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
104126
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
105127
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
106-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
107-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
108-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
109-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
110-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
128+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
129+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
130+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
131+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
132+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
111133
# torch.serialization.add_safe_globals([Conv])
112134
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
113135
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
114136
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
115137

138+
yolo_modules = [
139+
"ultralytics.nn.modules",
140+
"ultralytics.nn.tasks",
141+
"ultralytics.nn.autobackend",
142+
]
143+
for module_name in yolo_modules:
144+
try:
145+
mod = importlib.import_module(module_name)
146+
for name, obj in inspect.getmembers(mod):
147+
if inspect.isclass(obj):
148+
add_safe_globals([obj])
149+
except ModuleNotFoundError:
150+
pass # Some modules may not exist in all YOLO versions
151+
116152
model = YOLO("yolov5su.pt")
117153

118154
# Collect all unique classes used in the model

0 commit comments

Comments
 (0)