|
2 | 2 | from typing import Union |
3 | 3 | import numpy as np |
4 | 4 | import torch |
| 5 | +from torch.export import Dim |
5 | 6 | import torchvision |
| 7 | +import onnx |
| 8 | +import onnxslim |
6 | 9 | import onnxruntime as ort |
7 | 10 | from models import * |
8 | 11 |
|
@@ -263,21 +266,26 @@ def test(model: str, variant: str = ''): |
263 | 266 | torchvision.io.write_png(canvas, f'dog_{model}{variant}_output.png') |
264 | 267 |
|
265 | 268 |
|
266 | | -def export(model: str, variant: str = ''): |
| 269 | +def export(model: str, variant: str = '', onnx_path:str = '/tmp/model.onnx'): |
267 | 270 | net, has_obj = get_model(model, variant) |
268 | 271 | x = torch.randn(4, 3, 640, 640) |
269 | 272 | _ = net(x) # warmup all the einops kernels |
270 | 273 |
|
271 | 274 | print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ...", bcolors.ENDC) |
272 | | - torch.onnx.export(net, (x,), '/tmp/model.onnx', dynamo=False, |
| 275 | + torch.onnx.export(net, (x,), dynamo=True, opset_version=23, |
273 | 276 | input_names=['img'], |
274 | 277 | output_names=['preds'], |
275 | | - dynamic_axes={'img' : {0: 'B', 2: 'H', 3: 'W'}, |
276 | | - 'preds' : {0: 'B', 1: 'N'}}) |
| 278 | + dynamic_shapes={'x' : (Dim.DYNAMIC, Dim.STATIC, Dim.DYNAMIC, Dim.DYNAMIC)}).save(onnx_path) |
277 | 279 | print(bcolors.OKGREEN, f"Exporting {type(net).__name__} ... Done", bcolors.ENDC) |
278 | 280 |
|
| 281 | + print(bcolors.OKGREEN, f"Slimming {type(net).__name__} ...", bcolors.ENDC) |
| 282 | + model = onnx.load(onnx_path) |
| 283 | + slimmed_model = onnxslim.slim(model) |
| 284 | + onnx.save(slimmed_model, onnx_path) |
| 285 | + print(bcolors.OKGREEN, f"Slimming {type(net).__name__} ... Done", bcolors.ENDC) |
| 286 | + |
279 | 287 | print(bcolors.OKGREEN, "Checking with onnxruntime...", bcolors.ENDC) |
280 | | - netOrt = ort.InferenceSession('/tmp/model.onnx', providers=['CPUExecutionProvider']) |
| 288 | + netOrt = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']) |
281 | 289 | x = torch.randn(1, 3, 576, 768) |
282 | 290 | preds1 = net(x) |
283 | 291 | preds2, = netOrt.run(None, {'img': x.numpy()}) |
|
0 commit comments