Skip to content

Commit e55c258

Browse files
test
1 parent dded344 commit e55c258

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

tests/post_training/pipelines/image_classification_timm.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313
import onnx
14+
import onnxoptimizer
1415
import openvino as ov
1516
import timm
1617
import torch
@@ -45,17 +46,22 @@ def prepare_model(self) -> None:
4546
onnx_path = self.fp32_model_dir / "model_fp32.onnx"
4647
additional_kwargs = {}
4748
if self.batch_size > 1:
48-
additional_kwargs["input_names"] = ["image"]
49-
additional_kwargs["dynamic_axes"] = {"image": {0: "batch"}}
49+
batch = torch.export.Dim("batch")
50+
additional_kwargs["dynamic_shapes"] = ({0: batch},)
51+
5052
torch.onnx.export(
5153
timm_model,
5254
self.dummy_tensor,
5355
onnx_path,
5456
export_params=True,
5557
opset_version=13,
58+
dynamo=True,
5659
**additional_kwargs,
5760
)
58-
self.model = onnx.load(onnx_path)
61+
62+
model = onnx.load(onnx_path)
63+
passes = ["fuse_bn_into_conv"]
64+
self.model = onnxoptimizer.optimize(model, passes)
5965
self.input_name = self.model.graph.input[0].name
6066

6167
if self.backend in OV_BACKENDS + [BackendType.FP32]:

tests/post_training/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ transformers==4.53.0
2323
whowhatbench @ git+https://github.com/openvinotoolkit/[email protected]#subdirectory=tools/who_what_benchmark
2424
datasets==3.6.0
2525
onnxscript==0.5.4
26+
onnxoptimizer==0.3.8

0 commit comments

Comments
 (0)