Skip to content

Commit ae3a9e7

Browse files
committed
aot script updated
1 parent 1d975c7 commit ae3a9e7

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

examples/openvino/aot_openvino_compiler.py renamed to examples/openvino/aot_optimize_and_infer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def main(
181181
suite: str,
182182
model_name: str,
183183
input_shape,
184+
save_model: bool,
185+
model_file_name: str,
184186
quantize: bool,
185187
validate: bool,
186188
dataset_path: str,
@@ -198,6 +200,8 @@ def main(
198200
:param suite: The model suite to use (e.g., "timm", "torchvision", "huggingface").
199201
:param model_name: The name of the model to load.
200202
:param input_shape: The input shape for the model.
203+
:param save_model: Whether to save the compiled model as a .pte file.
204+
:param model_file_name: Custom file name to save the exported model.
201205
:param quantize: Whether to quantize the model.
202206
:param validate: Whether to validate the model.
203207
:param dataset_path: Path to the dataset for calibration/validation.
@@ -264,10 +268,12 @@ def main(
264268
)
265269

266270
# Serialize and save it to a file
267-
model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
268-
with open(model_file_name, "wb") as file:
269-
exec_prog.write_to_file(file)
270-
print(f"Model exported and saved as {model_file_name} on {device}.")
271+
if save_model:
272+
if not model_file_name:
273+
model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
274+
with open(model_file_name, "wb") as file:
275+
exec_prog.write_to_file(file)
276+
print(f"Model exported and saved as {model_file_name} on {device}.")
271277

272278
if validate:
273279
if suite == "huggingface":
@@ -315,6 +321,14 @@ def main(
315321
help="Batch size for the validation. Default batch_size == 1."
316322
" The dataset length must be evenly divisible by the batch size.",
317323
)
324+
parser.add_argument(
325+
"--export", action="store_true", help="Export the compiled model as .pte file."
326+
)
327+
parser.add_argument(
328+
"--model_file_name",
329+
type=str,
330+
help="Custom file name to save the exported model.",
331+
)
318332
parser.add_argument(
319333
"--quantize", action="store_true", help="Enable model quantization."
320334
)
@@ -367,6 +381,8 @@ def main(
367381
args.suite,
368382
args.model,
369383
args.input_shape,
384+
args.export,
385+
args.model_file_name,
370386
args.quantize,
371387
args.validate,
372388
args.dataset,

0 commit comments

Comments
 (0)