2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -5,10 +5,12 @@ Model Optimizer Changelog (Linux)
^^^^^^^^^^^^^^^^^

**Deprecations**
- Deprecated ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` to support strong typing. Use ``engine_precision`` instead.

**Bug Fixes**

**New Features**
- ``high_precision_dtype`` now defaults to fp16 in ONNX quantization, i.e., quantized output model weights are FP16 by default.

0.35 (2025-09-04)
^^^^^^^^^^^^^^^^^
4 changes: 2 additions & 2 deletions examples/onnx_ptq/README.md
@@ -120,7 +120,7 @@ The following evaluation requires the `val` directory of the [ImageNet dataset](
python evaluate.py \
--onnx_path=<path to classification model> \
--imagenet_path=<path to the ImageNet dataset> \
--quantize_mode=<fp8|int8|int4> \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
```

@@ -165,7 +165,7 @@ If the input model is of type image classification, use the following script to
python evaluate.py \
--onnx_path=<path to the exported ONNX model> \
--imagenet_path=<path to the ImageNet dataset> \
--quantize_mode=stronglyTyped \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
```

17 changes: 5 additions & 12 deletions examples/onnx_ptq/evaluate.py
@@ -48,29 +48,22 @@ def main():
parser.add_argument(
"--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
)
# By default, TensorRT autotunes tensor types to generate the fastest engine. When you specify
# to TensorRT that a network is strongly typed, it infers a type for each intermediate and
# output tensor using the rules in the operator type specification. For networks quantized in
# INT4 or FP8 mode, stronglyTyped as the mode is recommended for TensorRT deployment. Though
# INT8 networks are generally compiled with int8 mode, certain INT8 ViT networks compiled with
# stronglyTyped precision have shown better performance.
parser.add_argument(
"--quantize_mode",
"--engine_precision",
type=str,
default="stronglyTyped",
choices=["fp8", "fp16", "fp32", "int4", "int8", "int8_iq", "bf16", "best", "stronglyTyped"],
help="Quantization mode for the TensorRT engine. \
Supported options: fp8, fp16, fp32, int8, int8_iq(implicit quantization), bf16, best, stronglyTyped",
choices=["best", "fp16", "stronglyTyped"],
help="Precision mode for the TensorRT engine. \
stronglyTyped is recommended; all other modes have been deprecated in TensorRT",
)
parser.add_argument(
"--results_path", type=str, default=None, help="Save the results to the specified path"
)

args = parser.parse_args()

deployment = {
"runtime": "TRT",
"precision": args.quantize_mode,
"precision": args.engine_precision,
}

# Create an ONNX bytes object with the specified path
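For context, the precision string passed to the deployment dict above is ultimately translated into trtexec flags by the TRTMode table in modelopt/torch/_deploy/_runtime/tensorrt/constants.py (also touched in this PR). A minimal illustrative sketch of that mapping, not the actual implementation:

```python
# Illustrative sketch only; the authoritative mapping is the TRTMode flag table in
# modelopt/torch/_deploy/_runtime/tensorrt/constants.py.
TRTEXEC_FLAGS = {
    "fp16": ["--fp16"],
    "bf16": ["--bf16"],
    "best": ["--best"],
    "stronglyTyped": ["--stronglyTyped"],
}


def trtexec_precision_flags(engine_precision: str) -> list[str]:
    """Return the trtexec flags implied by an --engine_precision value."""
    return TRTEXEC_FLAGS[engine_precision]


print(trtexec_precision_flags("stronglyTyped"))  # ['--stronglyTyped']
```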
2 changes: 1 addition & 1 deletion examples/onnx_ptq/evaluation.py
@@ -29,7 +29,7 @@
deployment = {
"runtime": "TRT",
"accelerator": "GPU",
"precision": "fp32",
"precision": "stronglyTyped",
"onnx_opset": "21",
}

10 changes: 4 additions & 6 deletions modelopt/onnx/quantization/__main__.py
@@ -180,11 +180,11 @@ def get_parser() -> argparse.ArgumentParser:
argparser.add_argument(
"--high_precision_dtype",
type=str,
default=None,
default="fp16",
choices=["fp32", "fp16", "bf16"],
help=(
"High precision data type, one of ['fp32', 'fp16', 'bf16']. For int8 quantization, the default value is "
"'fp32' and 'fp16' for other quantization modes."
"High precision data type of the output model. If the input model is of dtype fp32, "
"it will be converted to fp16 dtype by default."
),
)
argparser.add_argument(
@@ -262,8 +262,6 @@ def main():
# Convert the NpzFile object to a Python dictionary
calibration_data = {key: calibration_data[key] for key in calibration_data.files}

default_high_precision_dtype = "fp32" if args.quantize_mode == "int8" else "fp16"

quantize(
args.onnx_path,
quantize_mode=args.quantize_mode,
@@ -284,7 +282,7 @@
log_file=args.log_file,
trt_plugins=args.trt_plugins,
trt_plugins_precision=args.trt_plugins_precision,
high_precision_dtype=args.high_precision_dtype or default_high_precision_dtype,
high_precision_dtype=args.high_precision_dtype,
mha_accumulation_dtype=args.mha_accumulation_dtype,
disable_mha_qdq=args.disable_mha_qdq,
dq_only=args.dq_only,
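To illustrate the new default, here is a minimal sketch using the Python API from modelopt/onnx/quantization/quantize.py; the model path is a placeholder and calibration inputs are omitted:

```python
# Minimal sketch of the new default; "model.onnx" is a placeholder path and
# calibration data is omitted for brevity.
from modelopt.onnx.quantization.quantize import quantize

quantize(
    "model.onnx",
    quantize_mode="int8",
    # high_precision_dtype now defaults to "fp16" for all quantize modes;
    # pass "fp32" explicitly to reproduce the previous int8 behavior.
    high_precision_dtype="fp16",
)
```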
2 changes: 1 addition & 1 deletion modelopt/onnx/quantization/int8.py
@@ -124,7 +124,7 @@ def quantize(
use_external_data_format: bool = False,
intermediate_generated_files: list[str] = [],
trt_extra_plugin_lib_paths: list[str] | None = None,
high_precision_dtype: str = "fp32",
high_precision_dtype: str = "fp16",
passes: list[str] = ["concat_elimination"],
log_level: str = "INFO",
calibrate_per_node: bool = False,
7 changes: 6 additions & 1 deletion modelopt/onnx/quantization/qdq_utils.py
@@ -529,6 +529,11 @@ def _get_successive_consumers(
quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
if not quantized_node:
raise ValueError(f"No consumer found for {dq_node.name}")
if quantized_node.op_type == "Cast":
next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
if not next_node:
raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
quantized_node = next_node

return dq_node, quantized_node

@@ -992,7 +997,7 @@ quantize_weights_to_int4(

# Remove transpose and reshape nodes
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
graph.node.clear()
del graph.node[:]
Contributor comment: Is there any reason we use this over graph.node.clear()?

graph.node.extend(new_nodes)

def is_fp32_cast(node: onnx.NodeProto) -> bool:
15 changes: 8 additions & 7 deletions modelopt/onnx/quantization/quantize.py
@@ -219,7 +219,7 @@ def quantize(
log_file: str | None = None,
trt_plugins: list[str] | None = None,
trt_plugins_precision: list[str] | None = None,
high_precision_dtype: str | None = None,
high_precision_dtype: str = "fp16",
mha_accumulation_dtype: str = "fp16",
disable_mha_qdq: bool = False,
dq_only: bool = True,
@@ -286,12 +286,13 @@
Each item should have the format <op_type>:<precision>, where precision can be fp32 (default) or fp16.
For example: op_type_1:fp16 op_type_2:fp32.
high_precision_dtype:
High precision data type, one of ['fp32', 'fp16']. If high_precision_dtype == 'fp16', model's weight and
activation will be converted to fp16.
High precision data type of the output model. If high_precision_dtype is 'fp16' or 'bf16'
and the input model is of dtype fp32, the model's weights and activations will be converted
to 'fp16' or 'bf16'.
mha_accumulation_dtype:
MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
MHA's bmm1 and bmm2's input and output tensors.
MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default. If quantize_mode == 'fp8' and
mha_accumulation_dtype == 'fp32', Cast nodes will be added to MHA's bmm1 and bmm2's input
and output tensors.
disable_mha_qdq:
Don't add Q/DQ layers to MatMuls in MHA pattern.
dq_only:
@@ -461,7 +462,7 @@ def quantize(
use_external_data_format=use_external_data_format,
intermediate_generated_files=intermediate_generated_files,
trt_extra_plugin_lib_paths=trt_plugins,
high_precision_dtype=high_precision_dtype, # type: ignore[arg-type]
high_precision_dtype=high_precision_dtype,
mha_accumulation_dtype=mha_accumulation_dtype,
passes=passes,
log_level=log_level,
2 changes: 0 additions & 2 deletions modelopt/torch/_deploy/_runtime/tensorrt/constants.py
@@ -86,7 +86,6 @@ class TRTMode:
BFLOAT16 = "bf16"
FLOAT8 = "fp8"
INT8 = "int8"
INT8_IQ = "int8_iq"
INT4 = "int4"
STRONGLY_TYPED = "stronglyTyped"
BEST = "best"
@@ -98,7 +97,6 @@ class TRTMode:
TRTMode.BFLOAT16: ["--bf16"],
TRTMode.FLOAT8: ["--fp16", "--fp8"],
TRTMode.INT8: ["--fp16", "--int8"],
TRTMode.INT8_IQ: ["--int8"],
TRTMode.INT4: ["--fp16", "--int4"],
TRTMode.STRONGLY_TYPED: ["--stronglyTyped"],
TRTMode.BEST: ["--best"],
@@ -103,10 +103,10 @@ def _get_trtexec_params(
def _is_low_bit_mode(trt_mode: str) -> bool:
return trt_mode in [
TRTMode.INT8,
TRTMode.INT8_IQ,
TRTMode.INT4,
TRTMode.FLOAT8,
TRTMode.BEST,
TRTMode.STRONGLY_TYPED,
]


@@ -24,7 +24,7 @@
from modelopt.onnx.utils import get_batch_size
from modelopt.onnx.utils import get_input_names as get_onnx_input_names

from .constants import TENSORRT_8_MAJOR_VERSION, TRTMode
from .constants import TENSORRT_8_MAJOR_VERSION


def is_trt8():
@@ -131,11 +131,6 @@ def get_output_shapes(
return output_shapes


def validate_precision(precision: str) -> bool:
"""Returns whether an input precision is in supported set."""
return precision in [TRTMode.FLOAT32, TRTMode.FLOAT16, TRTMode.INT8]


def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
"""The calibation data generator that yields calibration feed_dict to tensorrt."""
input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
1 change: 0 additions & 1 deletion modelopt/torch/_deploy/_runtime/trt_client.py
@@ -49,7 +49,6 @@ def deployment_table(self) -> DeploymentTable:
"bf16",
"fp8",
"int8",
"int8_iq",
"int4",
"stronglyTyped",
"best",
6 changes: 5 additions & 1 deletion tests/_test_utils/onnx_quantization/utils.py
@@ -20,7 +20,7 @@ def _assert_nodes_are_quantized(nodes):
for node in nodes:
for inp_idx, inp in enumerate(node.inputs):
if isinstance(inp, gs.Variable):
assert node.i(inp_idx).op == "DequantizeLinear", (
producer = node.i(inp_idx)
# Quantized path may include a Cast right after DQ
if producer and producer.op == "Cast":
producer = producer.i(0)
assert producer and producer.op == "DequantizeLinear", (
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
)
return True
11 changes: 8 additions & 3 deletions tests/examples/test_onnx_ptq.sh
@@ -161,18 +161,23 @@ for model_path in "${model_paths[@]}"; do
quant_mode="${all_modes[$i]}"
gpu_id=$((i % nvidia_gpu_count))

if [ "$quant_mode" == "fp16" ] || [ "$quant_mode" == "int8_iq" ]; then
if [ "$quant_mode" == "fp16" ]; then
eval_model_path=$model_dir/fp16/model.onnx
precision="fp16"
elif [ "$quant_mode" == "int8_iq" ]; then
eval_model_path=$model_dir/fp16/model.onnx
precision="best"
else
eval_model_path=$model_dir/$quant_mode/model.quant.onnx
precision="stronglyTyped"
fi

echo "Starting evaluation of $model_name for mode: $quant_mode on GPU $gpu_id"
if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--model_name="${timm_model_name[$model_name]}" \
--quantize_mode=$quant_mode \
--engine_precision=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
else
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
@@ -181,7 +186,7 @@ for model_path in "${model_paths[@]}"; do
--eval_data_size=$calib_size \
--batch_size $batch_size \
--model_name="${timm_model_name[$model_name]}" \
--quantize_mode=$quant_mode \
--engine_precision=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
fi
pids+=($!)
2 changes: 1 addition & 1 deletion tests/unit/onnx/test_qdq_rules_int8.py
@@ -41,7 +41,7 @@ def _assert_nodes_are_quantized(nodes):
def _assert_nodes_are_not_quantized(nodes):
for node in nodes:
for inp_idx, inp in enumerate(node.inputs):
if isinstance(inp, gs.Variable):
if isinstance(inp, gs.Variable) and inp.inputs:
assert node.i(inp_idx).op != "DequantizeLinear", (
f"Input '{inp.name}' of node '{node.name}' is quantized but shouldn't be!"
)