Skip to content

Commit 09814cf

Browse files
author
Ted Themistokleous
committed
Additional Fixes for save_load and adding CPU EP option
1 parent 7757055 commit 09814cf

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,7 @@ def parse_input_args():
312312
required=False,
313313
default="MIGraphX",
314314
type=str,
315-
help='The desired execution provider [MIGraphX, ROCm] for int8 quantization; Default is MIGraphX',
315+
help='The desired execution provider [MIGraphX, ROCm, CPU] for int8 quantization; Default is MIGraphX',
316316
)
317317

318318
parser.add_argument(
@@ -376,7 +376,7 @@ def parse_input_args():
376376

377377
parser.add_argument(
378378
"--save_load",
379-
action="store_false",
379+
action="store_true",
380380
required=False,
381381
default=False,
382382
help='Turn on Onnxruntime Model save loading to speed up inference',
@@ -471,11 +471,13 @@ def output_run_config(flags, samples):
471471
exit
472472

473473
cal_ep = "MIGraphXExecutionProvider"
474-
if not flags.ort_quant and flags.int8:
474+
if flags.int8:
475475
if flags.cal_ep == "MIGraphX":
476476
cal_ep = "MIGraphXExecutionProvider"
477477
elif flags.cal_ep == "ROCm":
478478
cal_ep = "ROCMExecutionProvider"
479+
elif flags.cal_ep == "CPU":
480+
cal_ep = "CPUExecutionProvider"
479481
else:
480482
print("Error: cal_ep:" + str(flags.cal_ep) + " Invalid")
481483
exit
@@ -597,10 +599,12 @@ def output_run_config(flags, samples):
597599
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0" # Disable MIGRAPHX FP16 precision
598600

599601
if flags.save_load:
602+
model_name = str(qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr"
603+
print("save load model from " + str(model_name))
600604
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL"] = "1"
601605
os.environ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL"] = "1"
602-
os.environ["ORT_MIGRAPHX_SAVE_COMPILE_PATH"] = (qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + (model_quants) + ".mxr"
603-
os.environ["ORT_MIGRAPHX_LOAD_COMPILE_PATH"] = (qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr"
606+
os.environ["ORT_MIGRAPHX_SAVE_COMPILE_PATH"] = model_name
607+
os.environ["ORT_MIGRAPHX_LOAD_COMPILE_PATH"] = model_name
604608

605609
# QDQ model inference and get SQUAD prediction
606610
batch_size = flags.batch

0 commit comments

Comments
 (0)