@@ -312,7 +312,7 @@ def parse_input_args():
312312 required = False ,
313313 default = "MIGraphX" ,
314314 type = str ,
315- help = 'The desired execution provider [MIGraphX, ROCm] for int8 quantization; Default is MIGraphX' ,
315+ help = 'The desired execution provider [MIGraphX, ROCm, CPU] for int8 quantization; Default is MIGraphX' ,
316316 )
317317
318318 parser .add_argument (
@@ -376,7 +376,7 @@ def parse_input_args():
376376
377377 parser .add_argument (
378378 "--save_load" ,
379- action = "store_false" ,
379+ action = "store_true" ,
380380 required = False ,
381381 default = False ,
382382 help = 'Turn on Onnxruntime Model save loading to speed up inference' ,
@@ -471,11 +471,13 @@ def output_run_config(flags, samples):
471471 exit
472472
473473 cal_ep = "MIGraphXExecutionProvider"
474- if not flags . ort_quant and flags .int8 :
474+ if flags .int8 :
475475 if flags .cal_ep == "MIGraphX" :
476476 cal_ep = "MIGraphXExecutionProvider"
477477 elif flags .cal_ep == "ROCm" :
478478 cal_ep = "ROCMExecutionProvider"
479+ elif flags .cal_ep == "CPU" :
480+ cal_ep = "CPUExecutionProvider"
479481 else :
480482 print ("Error: cal_ep:" + str (flags .cal_ep ) + " Invalid" )
481483 exit
@@ -597,10 +599,12 @@ def output_run_config(flags, samples):
597599 os .environ ["ORT_MIGRAPHX_FP16_ENABLE" ] = "0" # Disable MIGRAPHX FP16 precision
598600
599601 if flags .save_load :
602+ model_name = str (qdq_model_path ) + "_s" + str (flags .seq_len ) + "_b" + str (flags .batch ) + str (model_quants ) + ".mxr"
603+ print ("save load model from " + str (model_name ))
600604 os .environ ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL" ] = "1"
601605 os .environ ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL" ] = "1"
602- os .environ ["ORT_MIGRAPHX_SAVE_COMPILE_PATH" ] = ( qdq_model_path ) + "_s" + str ( flags . seq_len ) + "_b" + str ( flags . batch ) + ( model_quants ) + ".mxr"
603- os .environ ["ORT_MIGRAPHX_LOAD_COMPILE_PATH" ] = ( qdq_model_path ) + "_s" + str ( flags . seq_len ) + "_b" + str ( flags . batch ) + str ( model_quants ) + ".mxr"
606+ os .environ ["ORT_MIGRAPHX_SAVE_COMPILE_PATH" ] = model_name
607+ os .environ ["ORT_MIGRAPHX_LOAD_COMPILE_PATH" ] = model_name
604608
605609 # QDQ model inference and get SQUAD prediction
606610 batch_size = flags .batch
0 commit comments