Skip to content

Commit 820f470

Browse files
authored
Update darknet2onnx.py
Add dynamic axes for multi-batch cases.
1 parent af69fa6 commit 820f470

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

tool/darknet2onnx.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tool.darknet2pytorch import Darknet
44

55

6-
def transform_to_onnx(cfgfile, weightfile, batch_size=1,dynamic):
6+
def transform_to_onnx(cfgfile, weightfile, batch_size=1, dynamic=False):
77
model = Darknet(cfgfile)
88

99
model.print_network()
@@ -14,8 +14,8 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1,dynamic):
1414

1515
x = torch.randn((batch_size, 3, model.height, model.width), requires_grad=True) # .cuda()
1616

17-
if dynamics:
18-
17+
if dynamic:
18+
1919
onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx".format(batch_size, model.height, model.width)
2020
input_names = ["input"]
2121
output_names = ['boxes', 'confs']
@@ -38,17 +38,16 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1,dynamic):
3838
else:
3939
onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, model.height, model.width)
4040
torch.onnx.export(model,
41-
x,
42-
onnx_file_name,
43-
export_params=True,
44-
opset_version=11,
45-
do_constant_folding=True,
46-
input_names=['input'], output_names=['boxes', 'confs'],
47-
dynamic_axes=None)
41+
x,
42+
onnx_file_name,
43+
export_params=True,
44+
opset_version=11,
45+
do_constant_folding=True,
46+
input_names=['input'], output_names=['boxes', 'confs'],
47+
dynamic_axes=None)
4848

4949
print('Onnx model exporting done')
5050
return onnx_file_name
51-
5251

5352

5453
if __name__ == '__main__':

0 commit comments

Comments
 (0)