@@ -306,6 +306,15 @@ def parse_input_args():
306306 help = 'The desired execution provider [MIGraphX, ROCm] are the options; Default is MIGraphX' ,
307307 )
308308
309+ parser .add_argument (
310+ "--cal_ep" ,
311+ action = "store" ,
312+ required = False ,
313+ default = "MIGraphX" ,
314+ type = str ,
315+ help = 'The desired execution provider [MIGraphX, ROCm] for int8 quantization; Default is MIGraphX' ,
316+ )
317+
309318 parser .add_argument (
310319 "--model" ,
311320 action = "store" ,
@@ -359,9 +368,9 @@ def parse_input_args():
359368
360369 parser .add_argument (
361370 "--ort_quant" ,
362- action = "store_false " ,
371+ action = "store_true " ,
363372 required = False ,
364- default = "MIGraphX" ,
373+ default = False ,
365374 help = 'Turn on Onnxruntime Quantizer instead of MIGraphX Quantizer' ,
366375 )
367376
@@ -461,6 +470,17 @@ def output_run_config(flags, samples):
461470 print ("Error: EP:" + str (flags .ep ) + " Invalid" )
462471 exit
463472
473+ cal_ep = "MIGraphXExecutionProvider"
474+ if not flags .ort_quant and flags .int8 :
475+ if flags .cal_ep == "MIGraphX" :
476+ cal_ep = "MIGraphXExecutionProvider"
477+ elif flags .cal_ep == "ROCm" :
478+ cal_ep = "ROCMExecutionProvider"
479+ else :
480+ print ("Error: cal_ep:" + str (flags .cal_ep ) + " Invalid" )
481+ exit
482+
483+
464484 # Set squad version
465485 if flags .version == 1.1 :
466486 squad_json = "./squad/dev-v1.1.json"
@@ -506,9 +526,9 @@ def output_run_config(flags, samples):
506526 model = onnx .load_model (model_path )
507527
508528 # Generate INT8 calibration cache
509- print ("Calibration starts ..." )
529+ print ("Calibration data compute starts with " + str ( cal_ep ) )
510530 calibrator = create_calibrator (model_path , op_types_to_quantize , augmented_model_path = augmented_model_path , calibrate_method = CalibrationMethod .Percentile )
511- calibrator .set_execution_providers ([ep ])
531+ calibrator .set_execution_providers ([cal_ep ])
512532
513533 '''
514534 We can use one data reader to do data pre-processing, however,
@@ -538,7 +558,6 @@ def output_run_config(flags, samples):
538558
539559 if flags .ort_quant :
540560 print ("Int8 Quantization Done with Onnxruntime Quantizer" )
541- # Generate QDQ model
542561 mode = QuantizationMode .QLinearOps
543562 # In TRT, it recommended to add QDQ pair to inputs of Add node followed by ReduceMean node.
544563 # Mirroring here what TRT does in MIGraphX Quantization to be able to perform an apples to apples comparison
@@ -561,7 +580,7 @@ def output_run_config(flags, samples):
561580 print ("QDQ model is saved to " , qdq_model_path )
562581 else :
563582 qdq_model_path = model_path
564- print ("Int8 Quantization Done with " + ep )
583+ print ("Int8 Quantization Done with " + cal_ep )
565584 #Quantize with MIGraphX's INT8 quantizer instead
566585 os .environ ["ORT_MIGRAPHX_INT8_ENABLE" ] = "1" # Enable MIGRAPHX INT8 precision
567586 os .environ ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME" ] = "calibration.flatbuffers" # Calibration table name