
Commit 4eee40d

Merge pull request #31 from cavusmustafa/additional_updates
Additional Updates on Export and Infer Scripts
2 parents 3e17d09 + a02855f commit 4eee40d

File tree

4 files changed: +160, -286 lines

backends/openvino/README.md

Lines changed: 1 addition & 2 deletions
@@ -32,8 +32,7 @@ executorch
 │   └── requirements.txt
 └── examples
     └── openvino
-        ├── aot_openvino_compiler.py
-        ├── export_and_infer_openvino.py
+        ├── aot_optimize_and_infer.py
         └── README.md
 ```

examples/openvino/README.md

Lines changed: 39 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ Below is the layout of the `examples/openvino` directory, which includes the nec
 ```
 examples/openvino
 ├── README.md                    # Documentation for examples (this file)
-├── aot_openvino_compiler.py     # Example script for AoT export
-└── export_and_infer_openvino.py # Example script to export and execute models with python bindings
+└── aot_optimize_and_infer.py    # Example script to export and execute models
 ```

 # Build Instructions for Examples
@@ -20,14 +19,10 @@ Follow the [instructions](../../backends/openvino/README.md) of **Prerequisites*

 ## AOT step:

-The export script called `aot_openvino_compiler.py` allows users to export deep learning models from various model suites (TIMM, Torchvision, Hugging Face) to a openvino backend using **Executorch**. Users can dynamically specify the model, input shape, and target device.
+The Python script `aot_optimize_and_infer.py` allows users to export deep learning models from various model suites (TIMM, Torchvision, Hugging Face) to the OpenVINO backend using **ExecuTorch**. Users can dynamically specify the model, input shape, and target device.

 ### **Usage**

-#### **Command Structure**
-```bash
-python aot_openvino_compiler.py --suite <MODEL_SUITE> --model <MODEL_NAME> --input_shape <INPUT_SHAPE> --device <DEVICE>
-```

 #### **Arguments**
 - **`--suite`** (required):
@@ -50,6 +45,12 @@ python aot_openvino_compiler.py --suite <MODEL_SUITE> --model <MODEL_NAME> --input_shape <INPUT_SHAPE> --device <DEVICE>
   - `[1, 3, 224, 224]` (Zsh users: wrap in quotes)
   - `(1, 3, 224, 224)`

+- **`--export`** (optional):
+  Save the exported model as a `.pte` file.
+
+- **`--model_file_name`** (optional):
+  Specify a custom file name to save the exported model.
+
 - **`--batch_size`** :
   Batch size for the validation. Default batch_size == 1.
   The dataset length must be evenly divisible by the batch size.
@@ -63,35 +64,55 @@ python aot_openvino_compiler.py --suite <MODEL_SUITE> --model <MODEL_NAME> --input_shape <INPUT_SHAPE> --device <DEVICE>
 - **`--dataset`** (optional):
   Path to the imagenet-like calibration dataset.

+- **`--infer`** (optional):
+  Execute inference with the compiled model and report the average inference time.
+
+- **`--num_iter`** (optional):
+  Number of iterations to execute inference. Default is `1`.
+
+- **`--warmup_iter`** (optional):
+  Number of warmup iterations to execute before timing begins. Default is `0`.
+
+- **`--input_tensor_path`** (optional):
+  Path to the raw tensor file to be used as input for inference. If this argument is not provided, a random input tensor will be generated.
+
+- **`--output_tensor_path`** (optional):
+  Path to the raw tensor file where the inference output will be saved.
+
 - **`--device`** (optional)
   Target device for the compiled model. Default is `CPU`.
   Examples: `CPU`, `GPU`


-### **Examples**
+#### **Examples**

-#### Export a TIMM VGG16 model for the CPU
+##### Export a TIMM VGG16 model for the CPU
 ```bash
-python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU
+python aot_optimize_and_infer.py --export --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU
 ```

-#### Export a Torchvision ResNet50 model for the GPU
+##### Export a Torchvision ResNet50 model for the GPU
 ```bash
-python aot_openvino_compiler.py --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
+python aot_optimize_and_infer.py --export --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
 ```

-#### Export a Hugging Face BERT model for the CPU
+##### Export a Hugging Face BERT model for the CPU
+```bash
+python aot_optimize_and_infer.py --export --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU
+```
+##### Export and validate a TIMM ResNet50d model for the CPU
 ```bash
-python aot_openvino_compiler.py --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU
+python aot_optimize_and_infer.py --export --suite timm --model resnet50d --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset
 ```
-#### Export and validate TIMM Resnet50d model for the CPU
+
+##### Export, quantize and validate a TIMM ResNet50d model for the CPU
 ```bash
-python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset
+python aot_optimize_and_infer.py --export --suite timm --model resnet50d --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset --quantize
 ```

-#### Export, quantize and validate TIMM Resnet50d model for the CPU
+##### Execute inference with a Torchvision Inception V3 model for the CPU
 ```bash
-python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3, 224, 224] --device CPU --validate --dataset /path/to/dataset --quantize
+python aot_optimize_and_infer.py --suite torchvision --model inception_v3 --infer --warmup_iter 10 --num_iter 100 --input_shape "(1, 3, 256, 256)" --device CPU
 ```

 ### **Notes**
@@ -162,72 +183,3 @@ Run inference with a given model for 10 iterations:
     --model_path=model.pte \
     --num_executions=10
 ```
-
-## Running Python Example with Pybinding:
-
-You can use the `export_and_infer_openvino.py` script to run models with the OpenVINO backend through the Python bindings.
-
-### **Usage**
-
-#### **Command Structure**
-```bash
-python export_and_infer_openvino.py <ARGUMENTS>
-```
-
-#### **Arguments**
-- **`--suite`** (required if `--model_path` argument is not used):
-  Specifies the model suite to use. Needs to be used with `--model` argument.
-  Supported values:
-  - `timm` (e.g., VGG16, ResNet50)
-  - `torchvision` (e.g., resnet18, mobilenet_v2)
-  - `huggingface` (e.g., bert-base-uncased). NB: Quantization and validation is not supported yet.
-
-- **`--model`** (required if `--model_path` argument is not used):
-  Name of the model to export. Needs to be used with `--suite` argument.
-  Examples:
-  - For `timm`: `vgg16`, `resnet50`
-  - For `torchvision`: `resnet18`, `mobilenet_v2`
-  - For `huggingface`: `bert-base-uncased`, `distilbert-base-uncased`
-
-- **`--model_path`** (required if `--suite` and `--model` arguments are not used):
-  Path to the saved model file. This argument allows you to load the compiled model from a file, instead of downloading it from the model suites using the `--suite` and `--model` arguments.
-  Example: `<path to model foler>/resnet50_fp32.pte`
-
-- **`--input_shape`**(required for random inputs):
-  Input shape for the model. Provide this as a **list** or **tuple**.
-  Examples:
-  - `[1, 3, 224, 224]` (Zsh users: wrap in quotes)
-  - `(1, 3, 224, 224)`
-
-- **`--input_tensor_path`**(optional):
-  Path to the raw input tensor file. If this argument is not provided, a random input tensor will be generated with the input shape provided with `--input_shape` argument.
-  Example: `<path to the input tensor foler>/input_tensor.pt`
-
-- **`--output_tensor_path`**(optional):
-  Path to the file where the output raw tensor will be saved.
-  Example: `<path to the output tensor foler>/output_tensor.pt`
-
-- **`--device`** (optional)
-  Target device for the compiled model. Default is `CPU`.
-  Examples: `CPU`, `GPU`
-
-- **`--num_iter`** (optional)
-  Number of iterations to execute inference for evaluation. The default value is `1`.
-  Examples: `100`, `1000`
-
-- **`--warmup_iter`** (optional)
-  Number of warmup iterations to execute inference before evaluation. The default value is `0`.
-  Examples: `5`, `10`
-
-
-### **Examples**
-
-#### Execute Torchvision ResNet50 model for the GPU with Random Inputs
-```bash
-python export_and_infer_openvino.py --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
-```
-
-#### Run a Precompiled Model for the CPU Using an Existing Input Tensor File and Save the Output.
-```bash
-python export_and_infer_openvino.py --model_path /path/to/model/folder/resnet50_fp32.pte --input_tensor_file /path/to/input/folder/input.pt --output_tensor_file /path/to/output/folder/output.pt --device CPU
-```
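The `--input_tensor_path` and `--output_tensor_path` options documented above exchange raw tensor files via `torch.save`/`torch.load`, which is how the script itself reads and writes them. Below is a minimal sketch (not part of the commit) of preparing an input file and inspecting a saved output; the file names are illustrative:

```python
import torch

# The tensor shape must match the --input_shape passed to aot_optimize_and_infer.py.
torch.save(torch.randn(1, 3, 224, 224), "input_tensor.pt")

# After an --infer run with --output_tensor_path output_tensor.pt,
# the result can be read back the same way the script loads inputs.
output = torch.load("output_tensor.pt", weights_only=False)
print(type(output))
```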

examples/openvino/aot_openvino_compiler.py renamed to examples/openvino/aot_optimize_and_infer.py

Lines changed: 120 additions & 5 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import time

 import executorch

@@ -102,6 +103,54 @@ def load_calibration_dataset(
     return calibration_dataset


+def infer_model(
+    exec_prog: EdgeProgramManager,
+    input_shape,
+    num_iter: int,
+    warmup_iter: int,
+    input_path: str,
+    output_path: str,
+) -> float:
+    """
+    Executes inference and reports the average timing.
+
+    :param exec_prog: EdgeProgramManager of the lowered model.
+    :param input_shape: The input shape for the model.
+    :param num_iter: The number of iterations to execute inference for timing.
+    :param warmup_iter: The number of warmup iterations to execute before timing.
+    :param input_path: Path to the input tensor file to read the input for inference.
+    :param output_path: Path to the output tensor file to save the output of inference.
+    :return: The average inference time per iteration, in seconds.
+    """
+    # 1: Load model from buffer
+    executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
+
+    # 2: Initialize inputs
+    if input_path:
+        inputs = (torch.load(input_path, weights_only=False),)
+    else:
+        inputs = (torch.randn(input_shape),)
+
+    # 3: Execute warmup
+    for _i in range(warmup_iter):
+        out = executorch_module.run_method("forward", inputs)
+
+    # 4: Execute inference and measure timing
+    time_total = 0.0
+    for _i in range(num_iter):
+        time_start = time.time()
+        out = executorch_module.run_method("forward", inputs)
+        time_end = time.time()
+        time_total += time_end - time_start
+
+    # 5: Save output tensor as raw tensor file
+    if output_path:
+        torch.save(out, output_path)
+
+    # 6: Return average inference timing
+    return time_total / float(num_iter)
+
+
 def validate_model(
     exec_prog: EdgeProgramManager, calibration_dataset: torch.utils.data.DataLoader
 ) -> float:
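For context, a hedged sketch of how the new `infer_model` might be invoked, assuming `exec_prog` is the lowered `EdgeProgramManager` produced earlier in the script; empty paths fall through the function's truthiness checks, so a random input is generated and no output is saved:

```python
# Hypothetical call (values illustrative); exec_prog comes from the export step.
avg_time = infer_model(
    exec_prog,
    input_shape=(1, 3, 224, 224),
    num_iter=100,    # timed iterations
    warmup_iter=10,  # untimed warmup iterations
    input_path="",   # falsy: generate a random input of input_shape
    output_path="",  # falsy: do not save the output
)
print(f"Average inference time: {avg_time:.6f} s")
```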
@@ -128,27 +177,42 @@ def validate_model(
     return accuracy_score(predictions, targets)


-def main(
+def main(  # noqa: C901
     suite: str,
     model_name: str,
     input_shape,
+    save_model: bool,
+    model_file_name: str,
     quantize: bool,
     validate: bool,
     dataset_path: str,
     device: str,
     batch_size: int,
+    infer: bool,
+    num_iter: int,
+    warmup_iter: int,
+    input_path: str,
+    output_path: str,
 ):
     """
     Main function to load, quantize, and validate a model.

     :param suite: The model suite to use (e.g., "timm", "torchvision", "huggingface").
     :param model_name: The name of the model to load.
     :param input_shape: The input shape for the model.
+    :param save_model: Whether to save the compiled model as a .pte file.
+    :param model_file_name: Custom file name to save the exported model.
     :param quantize: Whether to quantize the model.
     :param validate: Whether to validate the model.
     :param dataset_path: Path to the dataset for calibration/validation.
     :param device: The device to run the model on (e.g., "cpu", "gpu").
     :param batch_size: Batch size for dataset loading.
+    :param infer: Whether to execute inference and report timing.
+    :param num_iter: The number of iterations to execute inference for timing.
+    :param warmup_iter: The number of warmup iterations to execute before timing.
+    :param input_path: Path to the input tensor file to read the input for inference.
+    :param output_path: Path to the output tensor file to save the output of inference.
+
     """

     # Load the selected model
@@ -214,10 +278,12 @@ def transform_fn(x):
     )

     # Serialize and save it to a file
-    model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
-    with open(model_file_name, "wb") as file:
-        exec_prog.write_to_file(file)
-    print(f"Model exported and saved as {model_file_name} on {device}.")
+    if save_model:
+        if not model_file_name:
+            model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
+        with open(model_file_name, "wb") as file:
+            exec_prog.write_to_file(file)
+        print(f"Model exported and saved as {model_file_name} on {device}.")

     if validate:
         if suite == "huggingface":
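A small worked example of the default naming in the hunk above (values hypothetical): with `--export` and `--quantize` but no `--model_file_name`, a `resnet50` export is written to `resnet50_int8.pte`:

```python
# Reproduces the f-string from the diff with example values.
model_name, quantize = "resnet50", True
default_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
assert default_name == "resnet50_int8.pte"
```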
@@ -232,6 +298,13 @@ def transform_fn(x):
         acc_top1 = validate_model(exec_prog, calibration_dataset)
         print(f"acc@1: {acc_top1}")

+    if infer:
+        print("Start inference of the model:")
+        avg_time = infer_model(
+            exec_prog, input_shape, num_iter, warmup_iter, input_path, output_path
+        )
+        print(f"Average inference time: {avg_time}")
+

 if __name__ == "__main__":
     # Argument parser for dynamic inputs
@@ -258,6 +331,14 @@ def transform_fn(x):
         help="Batch size for the validation. Default batch_size == 1."
         " The dataset length must be evenly divisible by the batch size.",
     )
+    parser.add_argument(
+        "--export", action="store_true", help="Export the compiled model as a .pte file."
+    )
+    parser.add_argument(
+        "--model_file_name",
+        type=str,
+        help="Custom file name to save the exported model.",
+    )
     parser.add_argument(
         "--quantize", action="store_true", help="Enable model quantization."
     )
@@ -266,6 +347,33 @@ def transform_fn(x):
         action="store_true",
         help="Enable model validation. --dataset argument is required for the validation.",
     )
+    parser.add_argument(
+        "--infer",
+        action="store_true",
+        help="Run inference and report timing.",
+    )
+    parser.add_argument(
+        "--num_iter",
+        type=int,
+        default=1,
+        help="The number of iterations to execute inference for timing.",
+    )
+    parser.add_argument(
+        "--warmup_iter",
+        type=int,
+        default=0,
+        help="The number of iterations to execute inference for warmup before timing.",
+    )
+    parser.add_argument(
+        "--input_tensor_path",
+        type=str,
+        help="Path to the input tensor file to read the input for inference.",
+    )
+    parser.add_argument(
+        "--output_tensor_path",
+        type=str,
+        help="Path to the output tensor file to save the output of inference.",
+    )
     parser.add_argument("--dataset", type=str, help="Path to the validation dataset.")
     parser.add_argument(
         "--device",
@@ -283,9 +391,16 @@ def transform_fn(x):
         args.suite,
         args.model,
         args.input_shape,
+        args.export,
+        args.model_file_name,
         args.quantize,
         args.validate,
         args.dataset,
         args.device,
         args.batch_size,
+        args.infer,
+        args.num_iter,
+        args.warmup_iter,
+        args.input_tensor_path,
+        args.output_tensor_path,
     )
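To close the loop, a sketch (an assumption, not shown in this commit) of running an exported `.pte` file with the same pybinding that `infer_model` uses; the import path is the usual ExecuTorch `portable_lib` location, and the file name matches the default produced by the export step:

```python
import torch
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Load the flatbuffer written by --export (default name for a resnet50 fp32 export).
with open("resnet50_fp32.pte", "rb") as f:
    module = _load_for_executorch_from_buffer(f.read())

# Run a forward pass with a random input matching the export-time input shape.
out = module.run_method("forward", (torch.randn(1, 3, 224, 224),))
print(out)
```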
