Skip to content

Commit 1f2a823

Browse files
authored
Update darknet2onnx.py
Add the dynamics axes for multi-batch cases.
1 parent cc78500 commit 1f2a823

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

tool/darknet2onnx.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tool.darknet2pytorch import Darknet
44

55

6-
def transform_to_onnx(cfgfile, weightfile, batch_size=1):
6+
def transform_to_onnx(cfgfile, weightfile, batch_size=1,dynamic):
77
model = Darknet(cfgfile)
88

99
model.print_network()
@@ -14,11 +14,30 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
1414

1515
x = torch.randn((batch_size, 3, model.height, model.width), requires_grad=True) # .cuda()
1616

17-
onnx_file_name = "yolov4_{}_3_{}_{}.onnx".format(batch_size, model.height, model.width)
17+
if dynamics:
18+
19+
onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx".format(batch_size, model.height, model.width)
20+
input_names = ["input"]
21+
output_names = ['boxes', 'confs']
1822

19-
# Export the model
20-
print('Export the onnx model ...')
21-
torch.onnx.export(model,
23+
dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
24+
# Export the model
25+
print('Export the onnx model ...')
26+
torch.onnx.export(model,
27+
x,
28+
onnx_file_name,
29+
export_params=True,
30+
opset_version=11,
31+
do_constant_folding=True,
32+
input_names=input_names, output_names=output_names,
33+
dynamic_axes=dynamic_axes)
34+
35+
print('Onnx model exporting done')
36+
return onnx_file_name
37+
38+
else:
39+
onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, model.height, model.width)
40+
torch.onnx.export(model,
2241
x,
2342
onnx_file_name,
2443
export_params=True,
@@ -27,8 +46,9 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
2746
input_names=['input'], output_names=['boxes', 'confs'],
2847
dynamic_axes=None)
2948

30-
print('Onnx model exporting done')
31-
return onnx_file_name
49+
print('Onnx model exporting done')
50+
return onnx_file_name
51+
3252

3353

3454
if __name__ == '__main__':
@@ -41,6 +61,12 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
4161
weightfile = sys.argv[2]
4262
batch_size = int(sys.argv[3])
4363
transform_to_onnx(cfgfile, weightfile, batch_size)
64+
elif len(sys.argv) == 5:
65+
cfgfile = sys.argv[1]
66+
weightfile = sys.argv[2]
67+
batch_size = int(sys.argv[3])
68+
dynamic = True if sys.argv[4] == 'True' else False
69+
transform_to_onnx(cfgfile, weightfile, batch_size, dynamic)
4470
else:
4571
print('Please execute this script this way:\n')
4672
print(' python darknet2onnx.py <cfgFile> <weightFile>')

0 commit comments

Comments
 (0)