Commit 7757055

Author: Ted Themistokleous

Add option for calibration EP data selection

1 parent: 5240ddc

File tree: 1 file changed

quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py

Lines changed: 25 additions & 6 deletions
@@ -306,6 +306,15 @@ def parse_input_args():
         help='The desired execution provider [MIGraphX, ROCm] are the options; Default is MIGraphX',
     )
 
+    parser.add_argument(
+        "--cal_ep",
+        action="store",
+        required=False,
+        default="MIGraphX",
+        type=str,
+        help='The desired execution provider [MIGraphX, ROCm] for int8 quantization; Default is MIGraphX',
+    )
+
     parser.add_argument(
         "--model",
         action="store",
@@ -359,9 +368,9 @@ def parse_input_args():
 
     parser.add_argument(
         "--ort_quant",
-        action="store_false",
+        action="store_true",
         required=False,
-        default="MIGraphX",
+        default=False,
         help='Turn on Onnxruntime Quantizer instead of MIGraphX Quantizer',
     )
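
Note on the --ort_quant hunk above: with the previous action="store_false" and the truthy string default "MIGraphX", flags.ort_quant evaluated as true unless the flag was passed, at which point it became False, the opposite of what the help text describes. With action="store_true" and default=False, the ONNX Runtime quantizer path is taken only when --ort_quant is given explicitly. A minimal standalone sketch of the corrected behaviour (an illustrative parser with just this one flag, not the script's full argument set):

import argparse

# Illustrative parser containing only the flag in question.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--ort_quant",
    action="store_true",
    required=False,
    default=False,
    help='Turn on Onnxruntime Quantizer instead of MIGraphX Quantizer',
)

print(parser.parse_args([]).ort_quant)               # False -> MIGraphX quantizer path
print(parser.parse_args(["--ort_quant"]).ort_quant)  # True  -> ONNX Runtime quantizer path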

@@ -461,6 +470,17 @@ def output_run_config(flags, samples):
         print("Error: EP:" + str(flags.ep) + " Invalid")
         exit
 
+    cal_ep = "MIGraphXExecutionProvider"
+    if not flags.ort_quant and flags.int8:
+        if flags.cal_ep == "MIGraphX":
+            cal_ep = "MIGraphXExecutionProvider"
+        elif flags.cal_ep == "ROCm":
+            cal_ep = "ROCMExecutionProvider"
+        else:
+            print("Error: cal_ep:" + str(flags.cal_ep) + " Invalid")
+            exit
+
+
     # Set squad version
     if flags.version == 1.1:
         squad_json = "./squad/dev-v1.1.json"
@@ -506,9 +526,9 @@ def output_run_config(flags, samples):
     model = onnx.load_model(model_path)
 
     # Generate INT8 calibration cache
-    print("Calibration starts ...")
+    print("Calibration data compute starts with " + str(cal_ep))
     calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile)
-    calibrator.set_execution_providers([ep])
+    calibrator.set_execution_providers([cal_ep])
 
     '''
     We can use one data reader to do data pre-processing, however,
@@ -538,7 +558,6 @@ def output_run_config(flags, samples):
 
     if flags.ort_quant:
         print("Int8 Quantization Done with Onnxruntime Quantizer")
-        # Generate QDQ model
         mode = QuantizationMode.QLinearOps
         # In TRT, it recommended to add QDQ pair to inputs of Add node followed by ReduceMean node.
         # Mirroring here what TRT does in MIGraphX Quantization to be able to perform an apples to apples comparison
@@ -561,7 +580,7 @@ def output_run_config(flags, samples):
         print("QDQ model is saved to ", qdq_model_path)
     else:
         qdq_model_path = model_path
-        print("Int8 Quantization Done with " + ep)
+        print("Int8 Quantization Done with " + cal_ep)
         #Quantize with MIGraphX's INT8 quantizer instead
         os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision
         os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
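
Taken together, the change decouples the execution provider used to compute the int8 calibration data from the one used for inference: --ep still selects the inference EP, while the new --cal_ep (default MIGraphX) selects the EP handed to calibrator.set_execution_providers(). A hypothetical invocation along the lines of

    python e2e_migraphx_bert_example.py --ep MIGraphX --cal_ep ROCm --int8

(the --int8 switch is inferred from the flags.int8 check above; all other arguments keep their defaults) would compute calibration data on ROCMExecutionProvider and run the quantized model on MIGraphXExecutionProvider. Note that the --cal_ep selection only takes effect on the MIGraphX quantizer path, i.e. when --ort_quant is not set and int8 quantization is requested; otherwise cal_ep keeps its MIGraphXExecutionProvider default.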
