33from tool .darknet2pytorch import Darknet
44
55
6- def transform_to_onnx (cfgfile , weightfile , batch_size = 1 ):
6+ def transform_to_onnx (cfgfile , weightfile , batch_size = 1 , dynamic ):
77 model = Darknet (cfgfile )
88
99 model .print_network ()
@@ -14,11 +14,30 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
1414
1515 x = torch .randn ((batch_size , 3 , model .height , model .width ), requires_grad = True ) # .cuda()
1616
17- onnx_file_name = "yolov4_{}_3_{}_{}.onnx" .format (batch_size , model .height , model .width )
17+ if dynamics :
18+
19+ onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx" .format (batch_size , model .height , model .width )
20+ input_names = ["input" ]
21+ output_names = ['boxes' , 'confs' ]
1822
19- # Export the model
20- print ('Export the onnx model ...' )
21- torch .onnx .export (model ,
23+ dynamic_axes = {"input" : {0 : "batch_size" }, "boxes" : {0 : "batch_size" }, "confs" : {0 : "batch_size" }}
24+ # Export the model
25+ print ('Export the onnx model ...' )
26+ torch .onnx .export (model ,
27+ x ,
28+ onnx_file_name ,
29+ export_params = True ,
30+ opset_version = 11 ,
31+ do_constant_folding = True ,
32+ input_names = input_names , output_names = output_names ,
33+ dynamic_axes = dynamic_axes )
34+
35+ print ('Onnx model exporting done' )
36+ return onnx_file_name
37+
38+ else :
39+ onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx" .format (batch_size , model .height , model .width )
40+ torch .onnx .export (model ,
2241 x ,
2342 onnx_file_name ,
2443 export_params = True ,
@@ -27,8 +46,9 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
2746 input_names = ['input' ], output_names = ['boxes' , 'confs' ],
2847 dynamic_axes = None )
2948
30- print ('Onnx model exporting done' )
31- return onnx_file_name
49+ print ('Onnx model exporting done' )
50+ return onnx_file_name
51+
3252
3353
3454if __name__ == '__main__' :
@@ -41,6 +61,12 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
4161 weightfile = sys .argv [2 ]
4262 batch_size = int (sys .argv [3 ])
4363 transform_to_onnx (cfgfile , weightfile , batch_size )
64+ elif len (sys .argv ) == 5 :
65+ cfgfile = sys .argv [1 ]
66+ weightfile = sys .argv [2 ]
67+ batch_size = int (sys .argv [3 ])
68+ dynamic = True if sys .argv [4 ] == 'True' else False
69+ transform_to_onnx (cfgfile , weightfile , batch_size , dynamic )
4470 else :
4571 print ('Please execute this script this way:\n ' )
4672 print (' python darknet2onnx.py <cfgFile> <weightFile>' )
0 commit comments