Skip to content

Commit cc78500

Browse files
authored
Update pytorch2onnx.py
Add the exporting graph as dynamics graph. In Jetson Xavier JetPack 4.4, multi-batch engine could not be built with statics engine.
1 parent 9904574 commit cc78500

File tree

1 file changed

+38
-13
lines changed

1 file changed

+38
-13
lines changed

tool/pytorch2onnx.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44

5-
def transform_to_onnx(cfgfile, weightfile, batch_size=1):
5+
def transform_to_onnx(cfgfile, weightfile, batch_size=1, dynamics=False):
66
model = Darknet(cfgfile)
77

88
model.print_network()
@@ -13,20 +13,39 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
1313

1414
x = torch.randn((batch_size, 3, model.height, model.width), requires_grad=True) # .cuda()
1515

16-
onnx_file_name = "yolov4_{}_3_{}_{}.onnx".format(batch_size, model.height, model.width)
16+
if dynamics:
17+
onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx".format(batch_size, model.height, model.width)
18+
input_names = ["input"]
19+
output_names = ['boxes', 'confs']
1720

18-
# Export the model
19-
print('Export the onnx model ...')
20-
torch.onnx.export(model,
21-
x,
22-
onnx_file_name,
23-
export_params=True,
24-
opset_version=11,
25-
do_constant_folding=True,
26-
dynamic_axes=None)
21+
dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
22+
# Export the model
23+
24+
print('Export the onnx model ...')
25+
torch.onnx.export(model,
26+
x,
27+
onnx_file_name,
28+
export_params=True,
29+
opset_version=11,
30+
do_constant_folding=True,
31+
input_names=input_names, output_names=output_names,
32+
dynamic_axes=dynamic_axes)
2733

28-
print('Onnx model exporting done')
29-
return onnx_file_name
34+
print('Onnx model exporting done')
35+
return onnx_file_name
36+
37+
else:
38+
onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, model.height, model.width)
39+
torch.onnx.export(model,
40+
x,
41+
onnx_file_name,
42+
export_params=True,
43+
opset_version=11,
44+
do_constant_folding=True,
45+
dynamic_axes=None)
46+
47+
print('Onnx model exporting done')
48+
return onnx_file_name
3049

3150

3251
if __name__ == '__main__':
@@ -39,6 +58,12 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1):
3958
weightfile = sys.argv[2]
4059
batch_size = int(sys.argv[3])
4160
transform_to_onnx(cfgfile, weightfile, batch_size)
61+
elif len(sys.argv) == 5:
62+
cfgfile = sys.argv[1]
63+
weightfile = sys.argv[2]
64+
batch_size = int(sys.argv[3])
65+
dynamics = True if sys.argv[4] == 'True' else False
66+
transform_to_onnx(cfgfile, weightfile, batch_size, dynamics)
4267
else:
4368
print('Please execute this script this way:\n')
4469
print(' python darknet2onnx.py <cfgFile> <weightFile>')

0 commit comments

Comments
 (0)