@@ -35,24 +35,31 @@ def test_generate(art_warning):
35
35
from ultralytics .nn .modules import Conv
36
36
from ultralytics import YOLO
37
37
38
- torch .serialization .add_safe_globals ([torch .nn .modules .container .Sequential ])
39
- torch .serialization .add_safe_globals ([torch .nn .modules .container .ModuleList ])
40
- torch .serialization .add_safe_globals ([torch .nn .modules .pooling .MaxPool2d ])
41
- torch .serialization .add_safe_globals ([torch .nn .modules .batchnorm .BatchNorm2d ])
42
- torch .serialization .add_safe_globals ([torch .nn .modules .activation .SiLU ])
43
- torch .serialization .add_safe_globals ([torch .nn .modules .conv .Conv2d ])
44
- torch .serialization .add_safe_globals ([torch .nn .modules .upsampling .Upsample ])
45
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
46
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
47
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
48
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
49
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
50
- torch .serialization .add_safe_globals ([Conv ])
51
- torch .serialization .add_safe_globals ([ultralytics .nn .modules .Conv ])
38
+ # torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
39
+ # torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
40
+ # torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
41
+ # torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
42
+ # torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
43
+ # torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
44
+ # torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
45
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
46
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
47
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
48
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
49
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
50
+ # torch.serialization.add_safe_globals([Conv])
51
+ # torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
52
52
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
53
53
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
54
54
55
55
model = YOLO ("yolov5su.pt" )
56
+
57
+ # Collect all unique classes used in the model
58
+ all_classes = set (type (module ) for module in model .modules ())
59
+
60
+ # Add them to safe_globals
61
+ torch .serialization .add_safe_globals (list (all_classes ))
62
+
56
63
py_model = PyTorchYolo (model = model , input_shape = (3 , 640 , 640 ), channels_first = True )
57
64
58
65
# Download a sample image
@@ -89,24 +96,31 @@ def test_check_params(art_warning):
89
96
from ultralytics import YOLO
90
97
from ultralytics .nn .modules import Conv
91
98
92
- torch .serialization .add_safe_globals ([torch .nn .modules .container .Sequential ])
93
- torch .serialization .add_safe_globals ([torch .nn .modules .container .ModuleList ])
94
- torch .serialization .add_safe_globals ([torch .nn .modules .pooling .MaxPool2d ])
95
- torch .serialization .add_safe_globals ([torch .nn .modules .batchnorm .BatchNorm2d ])
96
- torch .serialization .add_safe_globals ([torch .nn .modules .activation .SiLU ])
97
- torch .serialization .add_safe_globals ([torch .nn .modules .conv .Conv2d ])
98
- torch .serialization .add_safe_globals ([torch .nn .modules .upsampling .Upsample ])
99
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
100
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
101
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
102
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
103
- torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
104
- torch .serialization .add_safe_globals ([Conv ])
105
- torch .serialization .add_safe_globals ([ultralytics .nn .modules .Conv ])
99
+ # torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
100
+ # torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
101
+ # torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
102
+ # torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
103
+ # torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
104
+ # torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
105
+ # torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
106
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
107
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
108
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
109
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
110
+ # torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
111
+ # torch.serialization.add_safe_globals([Conv])
112
+ # torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
106
113
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
107
114
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
108
115
109
116
model = YOLO ("yolov5su.pt" )
117
+
118
+ # Collect all unique classes used in the model
119
+ all_classes = set (type (module ) for module in model .modules ())
120
+
121
+ # Add them to safe_globals
122
+ torch .serialization .add_safe_globals (list (all_classes ))
123
+
110
124
py_model = PyTorchYolo (model = model , input_shape = (3 , 640 , 640 ), channels_first = True )
111
125
112
126
with pytest .raises (ValueError ):
0 commit comments