@@ -35,23 +35,41 @@ def test_generate(art_warning):
35
35
from ultralytics .nn .modules import Conv
36
36
from ultralytics import YOLO
37
37
38
+ import importlib
39
+ import inspect
40
+ from torch .serialization import add_safe_globals
41
+
38
42
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
39
43
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
40
44
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
41
45
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
42
46
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
43
47
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
44
48
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
45
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
46
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
47
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
48
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
49
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
49
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
50
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
51
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
52
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
53
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
50
54
# torch.serialization.add_safe_globals([Conv])
51
55
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
52
56
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
53
57
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
54
58
59
+ yolo_modules = [
60
+ "ultralytics.nn.modules" ,
61
+ "ultralytics.nn.tasks" ,
62
+ "ultralytics.nn.autobackend" ,
63
+ ]
64
+ for module_name in yolo_modules :
65
+ try :
66
+ mod = importlib .import_module (module_name )
67
+ for name , obj in inspect .getmembers (mod ):
68
+ if inspect .isclass (obj ):
69
+ add_safe_globals ([obj ])
70
+ except ModuleNotFoundError :
71
+ pass # Some modules may not exist in all YOLO versions
72
+
55
73
model = YOLO ("yolov5su.pt" )
56
74
57
75
# Collect all unique classes used in the model
@@ -96,23 +114,41 @@ def test_check_params(art_warning):
96
114
from ultralytics import YOLO
97
115
from ultralytics .nn .modules import Conv
98
116
117
+ import importlib
118
+ import inspect
119
+ from torch .serialization import add_safe_globals
120
+
99
121
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
100
122
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
101
123
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
102
124
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
103
125
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
104
126
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
105
127
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
106
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
107
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
108
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
109
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
110
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
128
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
129
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
130
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
131
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
132
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
111
133
# torch.serialization.add_safe_globals([Conv])
112
134
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
113
135
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
114
136
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
115
137
138
+ yolo_modules = [
139
+ "ultralytics.nn.modules" ,
140
+ "ultralytics.nn.tasks" ,
141
+ "ultralytics.nn.autobackend" ,
142
+ ]
143
+ for module_name in yolo_modules :
144
+ try :
145
+ mod = importlib .import_module (module_name )
146
+ for name , obj in inspect .getmembers (mod ):
147
+ if inspect .isclass (obj ):
148
+ add_safe_globals ([obj ])
149
+ except ModuleNotFoundError :
150
+ pass # Some modules may not exist in all YOLO versions
151
+
116
152
model = YOLO ("yolov5su.pt" )
117
153
118
154
# Collect all unique classes used in the model
0 commit comments