diff --git a/tools/deployment/export_onnx.py b/tools/deployment/export_onnx.py index 4d60ae7..58fa807 100644 --- a/tools/deployment/export_onnx.py +++ b/tools/deployment/export_onnx.py @@ -51,8 +51,8 @@ def forward(self, images, orig_target_sizes): model = Model() - data = torch.rand(32, 3, 640, 640) - size = torch.tensor([[640, 640]]) + data = torch.rand(args.batch_size, 3, args.imgsz, args.imgsz) + size = torch.tensor([[args.imgsz, args.imgsz]]) _ = model(data, size) dynamic_axes = { @@ -69,7 +69,7 @@ def forward(self, images, orig_target_sizes): input_names=['images', 'orig_target_sizes'], output_names=['labels', 'boxes', 'scores'], dynamic_axes=dynamic_axes, - opset_version=16, + opset_version=args.opset, verbose=False, do_constant_folding=True, ) @@ -99,5 +99,8 @@ def forward(self, images, orig_target_sizes): parser.add_argument('--resume', '-r', type=str, ) parser.add_argument('--check', action='store_true', default=True,) parser.add_argument('--simplify', action='store_true', default=True,) + parser.add_argument('--opset', type=int, default=16,) + parser.add_argument('--imgsz', type=int, default=640,) + parser.add_argument('--batch_size', type=int, default=32,) args = parser.parse_args() main(args)