@@ -35,7 +35,20 @@ def test_generate(art_warning):
35
35
from ultralytics import YOLO
36
36
37
37
torch .serialization .add_safe_globals ([torch .nn .modules .container .Sequential ])
38
+ torch .serialization .add_safe_globals ([torch .nn .modules .container .ModuleList ])
39
+ torch .serialization .add_safe_globals ([torch .nn .modules .pooling .MaxPool2d ])
40
+ torch .serialization .add_safe_globals ([torch .nn .modules .batchnorm .BatchNorm2d ])
41
+ torch .serialization .add_safe_globals ([torch .nn .modules .activation .SiLU ])
42
+ torch .serialization .add_safe_globals ([torch .nn .modules .conv .Conv2d ])
43
+ torch .serialization .add_safe_globals ([torch .nn .modules .upsampling .Upsample ])
38
44
torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
45
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
46
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
47
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
48
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
49
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .Conv ])
50
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .Concat ])
51
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .DFL ])
39
52
40
53
model = YOLO ("yolov5su.pt" )
41
54
py_model = PyTorchYolo (model = model , input_shape = (3 , 640 , 640 ), channels_first = True )
@@ -73,9 +86,23 @@ def test_check_params(art_warning):
73
86
from ultralytics import YOLO
74
87
75
88
torch .serialization .add_safe_globals ([torch .nn .modules .container .Sequential ])
89
+ torch .serialization .add_safe_globals ([torch .nn .modules .container .ModuleList ])
90
+ torch .serialization .add_safe_globals ([torch .nn .modules .pooling .MaxPool2d ])
91
+ torch .serialization .add_safe_globals ([torch .nn .modules .batchnorm .BatchNorm2d ])
92
+ torch .serialization .add_safe_globals ([torch .nn .modules .activation .SiLU ])
93
+ torch .serialization .add_safe_globals ([torch .nn .modules .conv .Conv2d ])
94
+ torch .serialization .add_safe_globals ([torch .nn .modules .upsampling .Upsample ])
76
95
torch .serialization .add_safe_globals ([ultralytics .nn .tasks .DetectionModel ])
96
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .SPPF ])
97
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .C3 ])
98
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Bottleneck ])
99
+ torch .serialization .add_safe_globals ([ultralytics .nn .tasks .Detect ])
100
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .Conv ])
101
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .Concat ])
102
+ torch .serialization .add_safe_globals ([ultralytics .nn .modules .DFL ])
77
103
78
104
model = YOLO ("yolov5su.pt" )
105
+ print (model )
79
106
py_model = PyTorchYolo (model = model , input_shape = (3 , 640 , 640 ), channels_first = True )
80
107
81
108
with pytest .raises (ValueError ):
0 commit comments