22import torch
33
44
5- def transform_to_onnx (cfgfile , weightfile , batch_size = 1 ):
5+ def transform_to_onnx (cfgfile , weightfile , batch_size = 1 , dynamics = False ):
66 model = Darknet (cfgfile )
77
88 model .print_network ()
@@ -13,20 +13,39 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
1313
1414 x = torch .randn ((batch_size , 3 , model .height , model .width ), requires_grad = True ) # .cuda()
1515
16- onnx_file_name = "yolov4_{}_3_{}_{}.onnx" .format (batch_size , model .height , model .width )
16+ if dynamics :
17+ onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx" .format (batch_size , model .height , model .width )
18+ input_names = ["input" ]
19+ output_names = ['boxes' , 'confs' ]
1720
18- # Export the model
19- print ('Export the onnx model ...' )
20- torch .onnx .export (model ,
21- x ,
22- onnx_file_name ,
23- export_params = True ,
24- opset_version = 11 ,
25- do_constant_folding = True ,
26- dynamic_axes = None )
21+ dynamic_axes = {"input" : {0 : "batch_size" }, "boxes" : {0 : "batch_size" }, "confs" : {0 : "batch_size" }}
22+ # Export the model
23+
24+ print ('Export the onnx model ...' )
25+ torch .onnx .export (model ,
26+ x ,
27+ onnx_file_name ,
28+ export_params = True ,
29+ opset_version = 11 ,
30+ do_constant_folding = True ,
31+ input_names = input_names , output_names = output_names ,
32+ dynamic_axes = dynamic_axes )
2733
28- print ('Onnx model exporting done' )
29- return onnx_file_name
34+ print ('Onnx model exporting done' )
35+ return onnx_file_name
36+
37+ else :
38+ onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx" .format (batch_size , model .height , model .width )
39+ torch .onnx .export (model ,
40+ x ,
41+ onnx_file_name ,
42+ export_params = True ,
43+ opset_version = 11 ,
44+ do_constant_folding = True ,
45+ dynamic_axes = None )
46+
47+ print ('Onnx model exporting done' )
48+ return onnx_file_name
3049
3150
3251if __name__ == '__main__' :
@@ -39,6 +58,12 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
3958 weightfile = sys .argv [2 ]
4059 batch_size = int (sys .argv [3 ])
4160 transform_to_onnx (cfgfile , weightfile , batch_size )
61+ elif len (sys .argv ) == 5 :
62+ cfgfile = sys .argv [1 ]
63+ weightfile = sys .argv [2 ]
64+ batch_size = int (sys .argv [3 ])
65+ dynamics = True if sys .argv [4 ] == 'True' else False
66+ transform_to_onnx (cfgfile , weightfile , batch_size , dynamics )
4267 else :
4368 print ('Please execute this script this way:\n ' )
4469 print (' python darknet2onnx.py <cfgFile> <weightFile>' )
0 commit comments