Skip to content

Commit 9faea58

Browse files
author
me
committed
- use dynamo
- use onnxslim
1 parent 0dce609 commit 9faea58

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ torch
44
torchvision
55
einops
66
onnx
7+
onnxslim
78
onnxruntime
89
onnxscript>=0.5.7
910
lightning

src/test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
from typing import Union
33
import numpy as np
44
import torch
5+
from torch.export import Dim
56
import torchvision
7+
import onnx
8+
import onnxslim
69
import onnxruntime as ort
710
from models import *
811

@@ -263,21 +266,26 @@ def test(model: str, variant: str = ''):
263266
torchvision.io.write_png(canvas, f'dog_{model}{variant}_output.png')
264267

265268

266-
def export(model: str, variant: str = ''):
269+
def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'):
267270
net, has_obj = get_model(model, variant)
268271
x = torch.randn(4, 3, 640, 640)
269272
_ = net(x) # warmup all the einops kernels
270273

271274
print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ...", bcolors.ENDC)
272-
torch.onnx.export(net, (x,), '/tmp/model.onnx', dynamo=False,
275+
torch.onnx.export(net, (x,), dynamo=True, opset_version=23,
273276
input_names=['img'],
274277
output_names=['preds'],
275-
dynamic_axes={'img' : {0: 'B', 2: 'H', 3: 'W'},
276-
'preds' : {0: 'B', 1: 'N'}})
278+
dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)}).save(onnx_path)
277279
print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ... Done", bcolors.ENDC)
278280

281+
print(bcolors.OKGREEN, f"Slimming {type(net).__name__} ...", bcolors.ENDC)
282+
model = onnx.load(onnx_path)
283+
slimmed_model = onnxslim.slim(model)
284+
onnx.save(slimmed_model, onnx_path)
285+
print(bcolors.OKGREEN, f"Slimming {type(net).__name__} ... Done", bcolors.ENDC)
286+
279287
print(bcolors.OKGREEN, "Checking with onnxruntime...", bcolors.ENDC)
280-
netOrt = ort.InferenceSession('/tmp/model.onnx', providers=['CPUExecutionProvider'])
288+
netOrt = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
281289
x = torch.randn(1, 3, 576, 768)
282290
preds1 = net(x)
283291
preds2, = netOrt.run(None, {'img': x.numpy()})

src/test_unit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import numpy as np
33
import torch
4+
from torch.export import Dim
45
import onnxruntime as ort
56
from models import *
67

@@ -41,11 +42,10 @@ def test_export(net: YoloBase):
4142
net = net.eval()
4243
x = torch.randn(1, 3, 320, 320)
4344
_ = net(x) # compile einops kernels just in case
44-
torch.onnx.export(net, (x,), '/tmp/model.onnx', dynamo=False,
45-
input_names=['img'],
46-
output_names=['preds'],
47-
dynamic_axes={'img' : {0: 'B', 2: 'H', 3: 'W'},
48-
'preds' : {0: 'B', 1: 'N'}})
45+
torch.onnx.export(net, (x,), dynamo=True, opset_version=23,
46+
input_names=['img'],
47+
output_names=['preds'],
48+
dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)}).save('/tmp/model.onnx')
4949
netOrt = ort.InferenceSession('/tmp/model.onnx', providers=['CPUExecutionProvider'])
5050
x = torch.randn(2, 3, 576, 768)
5151
preds1 = net(x)

0 commit comments

Comments
 (0)