
Commit b872ebb

openvino_executor_runner.cpp can run on several inputs
1 parent: 29bc381

2 files changed: +214, -182 lines

examples/openvino/aot/aot_openvino_compiler.py (44 additions, 31 deletions)
@@ -5,8 +5,13 @@
 # directory of this source tree for more details.
 
 import argparse
+import os
+import shutil
+import subprocess
+from pathlib import Path
 
 import executorch
+import numpy as np
 import timm
 import torch
 import torchvision.datasets as datasets
@@ -19,9 +24,9 @@
 from sklearn.metrics import accuracy_score
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
-from torch.export import ExportedProgram
 from torch.export import export
 from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.graph_drawer import FxGraphDrawer
 from transformers import AutoModel
 
 import nncf
@@ -36,12 +41,14 @@ def load_model(suite: str, model_name: str):
         return timm.create_model(model_name, pretrained=True)
     elif suite == "torchvision":
         if not hasattr(torchvision_models, model_name):
-            raise ValueError(f"Model {model_name} not found in torchvision.")
+            msg = f"Model {model_name} not found in torchvision."
+            raise ValueError(msg)
         return getattr(torchvision_models, model_name)(pretrained=True)
     elif suite == "huggingface":
         return AutoModel.from_pretrained(model_name)
     else:
-        raise ValueError(f"Unsupported model suite: {suite}")
+        msg = f"Unsupported model suite: {suite}"
+        raise ValueError(msg)
 
 
 def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Module):
@@ -61,12 +68,32 @@ def load_calibration_dataset(dataset_path: str, suite: str, model: torch.nn.Module):
     return calibration_dataset
 
 
+def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
+    g = FxGraphDrawer(model, output_svg_path)
+    g.get_dot_graph().write_svg(output_svg_path)
+
+
+def dump_inputs(calibration_dataset, dest_path):
+    input_files, targets = [], []
+    for idx, data in enumerate(calibration_dataset):
+        feature, target = data
+        targets.append(target)
+        file_name = f"{dest_path}/input_{idx}_0.raw"
+        if not isinstance(feature, torch.Tensor):
+            feature = torch.tensor(feature)
+        feature.detach().numpy().tofile(file_name)
+        input_files.append(file_name)
+
+    return input_files, targets
+
+
 def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path: str, device: str):
     # Ensure input_shape is a tuple
     if isinstance(input_shape, list):
         input_shape = tuple(input_shape)
     elif not isinstance(input_shape, tuple):
-        raise ValueError("Input shape must be a list or tuple.")
+        msg = "Input shape must be a list or tuple."
+        raise ValueError(msg)
 
     # Load the selected model
     model = load_model(suite, model_name)
@@ -80,11 +107,13 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path: str, device: str):
 
     if quantize:
         if suite == "huggingface":
-            raise ValueError("Quantization of {suite} models did not support yet.")
+            msg = f"Quantization of {suite} models is not supported yet."
+            raise ValueError(msg)
 
         # Quantize model
         if not dataset_path:
-            raise ValueError("Quantization requires a calibration dataset.")
+            msg = "Quantization requires a calibration dataset."
+            raise ValueError(msg)
         calibration_dataset = load_calibration_dataset(dataset_path, suite, model)
 
         captured_model = aten_dialect.module()
@@ -101,6 +130,7 @@ def transform(x):
             calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform),
             fold_quantize=False,
         )
+        visualize_fx_model(quantized_model, f"{model_name}_int8.svg")
 
         aten_dialect: ExportedProgram = export(quantized_model, example_args)
 
@@ -123,54 +153,37 @@ def transform(x):
 
     if quantize:
         print("Start validation of the quantized model:")
-
         # 1: Dump inputs
-        import os
-        import shutil
-
-        dest_path = "tmp_inputs"
-        out_path = "tmp_outputs"
-        targets, input_files = [], []
+        dest_path = Path("tmp_inputs")
+        out_path = Path("tmp_outputs")
         for d in [dest_path, out_path]:
             if os.path.exists(d):
                 shutil.rmtree(d)
             os.makedirs(d)
-        input_list = ""
-        for idx, data in enumerate(calibration_dataset):
-            feature, target = data
-            targets.append(target)
-            file_name = f"{dest_path}/input_{idx}_0.raw"
-            input_list += file_name + " "
-            if not isinstance(feature, torch.Tensor):
-                feature = torch.tensor(feature)
-            feature.detach().numpy().tofile(file_name)
-            input_files.append(file_name)
-
-        inp_list_file = os.path.join(dest_path, "in_list.txt")
+
+        input_files, targets = dump_inputs(calibration_dataset, dest_path)
+        inp_list_file = dest_path / "in_list.txt"
         with open(inp_list_file, "w") as f:
-            input_list = input_list.strip() + "\n"
-            f.write(input_list)
+            f.write("\n".join(input_files) + "\n")
 
         # 2: Run the executor
         print("Run openvino_executor_runner...")
-        import subprocess
 
         subprocess.run(
            [
                 "../../../cmake-openvino-out/examples/openvino/openvino_executor_runner",
                 f"--model_path={model_name}",
                 f"--input_list_path={inp_list_file}",
                 f"--output_folder_path={out_path}",
-                # f"--num_iter={len(input_files)}"
            ]
         )
 
         # 3: load the outputs and compare with the targets
-        import numpy as np
 
         predictions = []
         for i in range(len(input_files)):
-            predictions.append(np.fromfile(os.path.join(out_path, f"output_{i}.raw"), dtype=np.float32))
+            tensor = np.fromfile(out_path / f"output_{i}_0.raw", dtype=np.float32)
+            predictions.append(torch.tensor(np.argmax(tensor)))
 
         acc_top1 = accuracy_score(predictions, targets)
         print(f"acc@1: {acc_top1}")
