Commit fcbdc31
[NVBUG: 5373030] Disable the weight adjustment for int32 bias from onnxruntime (#510)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:**
- Disable the weight adjustment for int32 bias in onnxruntime by default

## Usage

```shell
python -m modelopt.onnx.quantization --onnx_path=code031_gemm_batch.onnx --simplify --calibration_eps trt --quantize_mode fp8 --disable_mha_qdq --high_precision_dtype fp16
```

## Testing

Able to quantize the code031_gemm_batch.onnx model.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

Signed-off-by: ajrasane <[email protected]>
1 parent 69c258f commit fcbdc31

File tree: 3 files changed (+9, -0 lines)

modelopt/onnx/quantization/fp8.py (4 additions, 0 deletions)

```diff
@@ -272,6 +272,10 @@ def quantize(
         trt_guided_options["group_qdq_tensors"] = group_qdq_tensors
         logger.debug(f"Grouping QDQ tensors for concat elimination: {group_qdq_tensors}")

+    # Add disable_int32_weight_adjustment flag to extra options
+    trt_guided_options["QDQDisableWeightAdjustForInt32Bias"] = True
+    logger.debug("Disabled weight adjustment for INT32 bias in QDQ quantization")
+
     # Create a temp file for intermediate model
     tmp_onnx_file, tmp_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(tmp_onnx_file)
```
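For context on what this flag turns off: in QDQ quantization, a Conv/Gemm bias is typically quantized to INT32 with a scale equal to `input_scale * weight_scale`; when a bias value would not fit that scale well, the quantizer may adjust the weight scale to compensate, which is the behavior being disabled here. The sketch below illustrates only the bias-to-INT32 convention (it is not the onnxruntime implementation; the function name and sample values are invented for illustration):

```python
import numpy as np

def quantize_bias_int32(bias, input_scale, weight_scale):
    """Illustrative sketch: quantize a float bias to INT32 using
    scale = input_scale * weight_scale (the usual QDQ bias convention)."""
    bias_scale = input_scale * weight_scale
    q = np.round(bias / bias_scale)
    # With weight adjustment disabled, out-of-range values are simply
    # saturated to the INT32 limits instead of rescaling the weights.
    info = np.iinfo(np.int32)
    return np.clip(q, info.min, info.max).astype(np.int32)

# Hypothetical per-channel scales, purely for demonstration
bias = np.array([0.5, -1.25, 3.0], dtype=np.float32)
q_bias = quantize_bias_int32(
    bias, input_scale=0.02, weight_scale=np.array([0.01, 0.05, 0.005])
)
print(q_bias)  # [ 2500 -1250 30000]
```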

modelopt/onnx/quantization/int8.py (4 additions, 0 deletions)

```diff
@@ -237,6 +237,10 @@ def quantize(
         trt_guided_options["group_qdq_tensors"] = group_qdq_tensors
         logger.debug(f"Found {len(group_qdq_tensors)} tensor groups for concat elimination")

+    # Add disable_int32_weight_adjustment flag to extra options
+    trt_guided_options["QDQDisableWeightAdjustForInt32Bias"] = True
+    logger.debug("Disabled weight adjustment for INT32 bias in QDQ quantization")
+
     # Create a temp file for intermediate model
     tmp_onnx_file, tmp_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(tmp_onnx_file)
```

modelopt/onnx/quantization/ort_patching.py (1 addition, 0 deletions)

```diff
@@ -1600,6 +1600,7 @@ def _quantize_static(
         ("TrtExtraPluginLibraryPaths", "trt_extra_plugin_lib_paths"),
         ("ExecutionProviders", "execution_providers"),
         ("group_qdq_tensors", "group_qdq_tensors"),
+        ("QDQDisableWeightAdjustForInt32Bias", "disable_int32_weight_adjustment"),
         # ==========================================================
     ]
     calib_extra_options = {
```
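The hunk above extends a list of `(ExtraOptionName, kwarg_name)` pairs that translates user-facing extra options into internal keyword arguments, which is how the new `QDQDisableWeightAdjustForInt32Bias` flag reaches the calibrator as `disable_int32_weight_adjustment`. A minimal sketch of that mapping pattern (function name and structure are illustrative, not the actual patched code):

```python
# Pairs of (user-facing extra-option key, internal kwarg name),
# mirroring the two entries relevant to this change.
OPTION_MAP = [
    ("group_qdq_tensors", "group_qdq_tensors"),
    ("QDQDisableWeightAdjustForInt32Bias", "disable_int32_weight_adjustment"),
]

def map_extra_options(extra_options: dict) -> dict:
    """Copy any recognized extra options into internal kwargs."""
    kwargs = {}
    for option_key, kwarg_name in OPTION_MAP:
        if option_key in extra_options:
            kwargs[kwarg_name] = extra_options[option_key]
    return kwargs

print(map_extra_options({"QDQDisableWeightAdjustForInt32Bias": True}))
# {'disable_int32_weight_adjustment': True}
```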
