Skip to content

Commit 5d5f900

Browse files
author
me
committed
split onnx export into torch.export then torch.onnx.export with the "torch.export"-ed program. This is a workaround for a bug introduced in new release of torch version 2.10.0
1 parent ab068c9 commit 5d5f900

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/test.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ def get_model(model: str, variant: str = ''):
250250
has_obj = False
251251

252252
elif model == 'yolov7':
253-
load_from_yolov7_official(net, '../weights/yolov7.pt')
253+
if os.path.exists('../weights/yolov7.pt'):
254+
load_from_yolov7_official(net, '../weights/yolov7.pt')
254255

255256
return net, has_obj
256257

@@ -272,12 +273,15 @@ def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
272273
_ = net(x) # warmup all the einops kernels
273274

274275
print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ...", bcolors.ENDC)
275-
torch.onnx.export(net, (x,), dynamo=True, opset_version=23,
276-
input_names=['img'],
277-
output_names=['preds'],
278-
dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)}).save(onnx_path)
276+
prog = torch.export.export(net, (x,), dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)})
279277
print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ... Done", bcolors.ENDC)
280278

279+
print(bcolors.OKGREEN, f"ONNX {type(net).__name__} ...", bcolors.ENDC)
280+
torch.onnx.export(prog, dynamo=True, opset_version=23,
281+
input_names=['img'],
282+
output_names=['preds']).save(onnx_path)
283+
print(bcolors.OKGREEN, f"ONNX {type(net).__name__} ... Done", bcolors.ENDC)
284+
281285
print(bcolors.OKGREEN, f"Slimming {type(net).__name__} ...", bcolors.ENDC)
282286
model = onnx.load(onnx_path)
283287
slimmed_model = onnxslim.slim(model)
@@ -369,3 +373,8 @@ def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
369373
# export('yolov12', 'm')
370374
# export('yolov12', 'l')
371375
# export('yolov12', 'x')
376+
# export('yolov26', 'n')
377+
# export('yolov26', 's')
378+
# export('yolov26', 'm')
379+
# export('yolov26', 'l')
380+
# export('yolov26', 'x')

src/unit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def test_export(net: YoloBase):
4343
net = net.eval()
4444
x = torch.randn(4, 3, 640, 640)
4545
_ = net(x) # compile einops kernels just in case
46-
torch.onnx.export(net, (x,), dynamo=True, opset_version=23,
46+
prog = torch.export.export(net, (x,), dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)})
47+
torch.onnx.export(prog, dynamo=True, opset_version=23,
4748
input_names=['img'],
48-
output_names=['preds'],
49-
dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)}).save('/tmp/model.onnx')
49+
output_names=['preds']).save('/tmp/model.onnx')
5050
netOrt = ort.InferenceSession('/tmp/model.onnx', providers=['CPUExecutionProvider'])
5151
x = torch.randn(2, 3, 576, 768)
5252
preds1 = net(x)

0 commit comments

Comments
 (0)