2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -5,10 +5,12 @@ Model Optimizer Changelog (Linux)
^^^^^^^^^^^^^^^^^

**Deprecations**
- Deprecated the ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` in favor of strong typing. Use ``engine_precision`` instead.

**Bug Fixes**

**New Features**
- ``high_precision_dtype`` now defaults to fp16 in ONNX quantization, i.e. the quantized output model's weights are now FP16 by default.

0.35 (2025-09-04)
^^^^^^^^^^^^^^^^^
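As a quick illustration of the flag rename in the evaluation script (a sketch with placeholder paths; the flags themselves are taken from the README examples below):

```bash
# Before (deprecated): precision was derived from the quantization mode
python evaluate.py \
    --onnx_path=model.quant.onnx \
    --imagenet_path=/data/imagenet \
    --quantize_mode=int8 \
    --model_name=vit_base_patch16_224

# After: request a strongly typed TensorRT engine explicitly
python evaluate.py \
    --onnx_path=model.quant.onnx \
    --imagenet_path=/data/imagenet \
    --engine_precision=stronglyTyped \
    --model_name=vit_base_patch16_224
```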
4 changes: 2 additions & 2 deletions examples/onnx_ptq/README.md
@@ -120,7 +120,7 @@ The following evaluation requires the `val` directory of the [ImageNet dataset](
python evaluate.py \
--onnx_path=<path to classification model> \
--imagenet_path=<path to the ImageNet dataset> \
--quantize_mode=<fp8|int8|int4> \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
```

@@ -165,7 +165,7 @@ If the input model is of type image classification, use the following script to
python evaluate.py \
--onnx_path=<path to the exported ONNX model> \
--imagenet_path=<path to the ImageNet dataset> \
--quantize_mode=stronglyTyped \
--engine_precision=stronglyTyped \
--model_name=vit_base_patch16_224
```

17 changes: 5 additions & 12 deletions examples/onnx_ptq/evaluate.py
@@ -48,29 +48,22 @@ def main():
parser.add_argument(
"--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
)
# By default, TensorRT autotunes tensor types to generate the fastest engine. When you specify
# to TensorRT that a network is strongly typed, it infers a type for each intermediate and
# output tensor using the rules in the operator type specification. For networks quantized in
# INT4 or FP8 mode, stronglyTyped as the mode is recommended for TensorRT deployment. Though
# INT8 networks are generally compiled with int8 mode, certain INT8 ViT networks compiled with
# stronglyTyped precision have shown better performance.
parser.add_argument(
"--quantize_mode",
"--engine_precision",
type=str,
default="stronglyTyped",
choices=["fp8", "fp16", "fp32", "int4", "int8", "int8_iq", "bf16", "best", "stronglyTyped"],
help="Quantization mode for the TensorRT engine. \
Supported options: fp8, fp16, fp32, int8, int8_iq(implicit quantization), bf16, best, stronglyTyped",
choices=["best", "fp16", "stronglyTyped"],
help="Precision mode for the TensorRT engine. \
stronglyTyped is recommended, all other modes have been deprecated in TensorRT",
)
parser.add_argument(
"--results_path", type=str, default=None, help="Save the results to the specified path"
)

args = parser.parse_args()

deployment = {
"runtime": "TRT",
"precision": args.quantize_mode,
"precision": args.engine_precision,
}

# Create an ONNX bytes object with the specified path
2 changes: 1 addition & 1 deletion examples/onnx_ptq/evaluation.py
@@ -29,7 +29,7 @@
deployment = {
"runtime": "TRT",
"accelerator": "GPU",
"precision": "fp32",
"precision": "stronglyTyped",
"onnx_opset": "21",
}

10 changes: 4 additions & 6 deletions modelopt/onnx/quantization/__main__.py
@@ -180,11 +180,11 @@ def get_parser() -> argparse.ArgumentParser:
argparser.add_argument(
"--high_precision_dtype",
type=str,
default=None,
default="fp16",
choices=["fp32", "fp16", "bf16"],
help=(
"High precision data type, one of ['fp32', 'fp16', 'bf16']. For int8 quantization, the default value is "
"'fp32' and 'fp16' for other quantization modes."
"High precision data type of the output model. If the input model is of dtype fp32, "
"it will be converted to fp16 dtype by default."
),
)
argparser.add_argument(
@@ -262,8 +262,6 @@ def main():
# Convert the NpzFile object to a Python dictionary
calibration_data = {key: calibration_data[key] for key in calibration_data.files}

default_high_precision_dtype = "fp32" if args.quantize_mode == "int8" else "fp16"

quantize(
args.onnx_path,
quantize_mode=args.quantize_mode,
@@ -284,7 +282,7 @@
log_file=args.log_file,
trt_plugins=args.trt_plugins,
trt_plugins_precision=args.trt_plugins_precision,
high_precision_dtype=args.high_precision_dtype or default_high_precision_dtype,
high_precision_dtype=args.high_precision_dtype,
mha_accumulation_dtype=args.mha_accumulation_dtype,
disable_mha_qdq=args.disable_mha_qdq,
dq_only=args.dq_only,
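With the mode-dependent default removed (previously fp32 for int8 and fp16 otherwise), int8 users who relied on FP32 outputs now have to request them explicitly. A minimal sketch, assuming the CLI is invoked as `python -m modelopt.onnx.quantization` and exposes `--onnx_path`/`--quantize_mode` with these spellings (only `--high_precision_dtype` appears in this diff):

```bash
# New default: an FP32 input model is quantized with FP16 weights and activations
python -m modelopt.onnx.quantization --onnx_path=model.onnx --quantize_mode=int8

# Restore the previous int8 behavior by asking for FP32 explicitly
python -m modelopt.onnx.quantization --onnx_path=model.onnx --quantize_mode=int8 \
    --high_precision_dtype=fp32
```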
2 changes: 1 addition & 1 deletion modelopt/onnx/quantization/int8.py
@@ -124,7 +124,7 @@ def quantize(
use_external_data_format: bool = False,
intermediate_generated_files: list[str] = [],
trt_extra_plugin_lib_paths: list[str] | None = None,
high_precision_dtype: str = "fp32",
high_precision_dtype: str = "fp16",
passes: list[str] = ["concat_elimination"],
log_level: str = "INFO",
calibrate_per_node: bool = False,
7 changes: 6 additions & 1 deletion modelopt/onnx/quantization/qdq_utils.py
@@ -529,6 +529,11 @@ def _get_successive_consumers(
quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
if not quantized_node:
raise ValueError(f"No consumer found for {dq_node.name}")
if quantized_node.op_type == "Cast":
next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
if not next_node:
raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
quantized_node = next_node

return dq_node, quantized_node

@@ -997,7 +1002,7 @@ def quantize_weights_to_int4(

# Remove transpose and reshape nodes
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
graph.node.clear()
del graph.node[:]
Contributor (review comment): Is there any reason we use this over `graph.node.clear()`?

graph.node.extend(new_nodes)

def is_fp32_cast(node: onnx.NodeProto) -> bool:
15 changes: 8 additions & 7 deletions modelopt/onnx/quantization/quantize.py
@@ -219,7 +219,7 @@ def quantize(
log_file: str | None = None,
trt_plugins: list[str] | None = None,
trt_plugins_precision: list[str] | None = None,
high_precision_dtype: str | None = None,
high_precision_dtype: str = "fp16",
mha_accumulation_dtype: str = "fp16",
disable_mha_qdq: bool = False,
dq_only: bool = True,
@@ -286,12 +286,13 @@
Each item should have the format <op_type>:<precision>, where precision can be fp32 (default) or fp16.
For example: op_type_1:fp16 op_type_2:fp32.
high_precision_dtype:
High precision data type, one of ['fp32', 'fp16']. If high_precision_dtype == 'fp16', model's weight and
activation will be converted to fp16.
High precision data type of the output model. If high_precision_dtype is 'fp16' or 'bf16'
and the input model is of dtype fp32, the model's weights and activations will be converted
to 'fp16' or 'bf16'.
mha_accumulation_dtype:
MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
MHA's bmm1 and bmm2's input and output tensors.
MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default. If quantize_mode == 'fp8' and
mha_accumulation_dtype == 'fp32', Cast nodes will be added to MHA's bmm1 and bmm2's input
and output tensors.
disable_mha_qdq:
Don't add Q/DQ layers to MatMuls in MHA pattern.
dq_only:
@@ -461,7 +462,7 @@ def quantize(
use_external_data_format=use_external_data_format,
intermediate_generated_files=intermediate_generated_files,
trt_extra_plugin_lib_paths=trt_plugins,
high_precision_dtype=high_precision_dtype, # type: ignore[arg-type]
high_precision_dtype=high_precision_dtype,
mha_accumulation_dtype=mha_accumulation_dtype,
passes=passes,
log_level=log_level,
2 changes: 0 additions & 2 deletions modelopt/torch/_deploy/_runtime/tensorrt/constants.py
@@ -86,7 +86,6 @@ class TRTMode:
BFLOAT16 = "bf16"
FLOAT8 = "fp8"
INT8 = "int8"
INT8_IQ = "int8_iq"
INT4 = "int4"
STRONGLY_TYPED = "stronglyTyped"
BEST = "best"
@@ -98,7 +97,6 @@ class TRTMode:
TRTMode.BFLOAT16: ["--bf16"],
TRTMode.FLOAT8: ["--fp16", "--fp8"],
TRTMode.INT8: ["--fp16", "--int8"],
TRTMode.INT8_IQ: ["--int8"],
TRTMode.INT4: ["--fp16", "--int4"],
TRTMode.STRONGLY_TYPED: ["--stronglyTyped"],
TRTMode.BEST: ["--best"],
@@ -103,10 +103,10 @@ def _get_trtexec_params(
def _is_low_bit_mode(trt_mode: str) -> bool:
return trt_mode in [
TRTMode.INT8,
TRTMode.INT8_IQ,
TRTMode.INT4,
TRTMode.FLOAT8,
TRTMode.BEST,
TRTMode.STRONGLY_TYPED,
]


@@ -24,7 +24,7 @@
from modelopt.onnx.utils import get_batch_size
from modelopt.onnx.utils import get_input_names as get_onnx_input_names

from .constants import TENSORRT_8_MAJOR_VERSION, TRTMode
from .constants import TENSORRT_8_MAJOR_VERSION


def is_trt8():
@@ -131,11 +131,6 @@ def get_output_shapes(
return output_shapes


def validate_precision(precision: str) -> bool:
"""Returns whether an input precision is in supported set."""
return precision in [TRTMode.FLOAT32, TRTMode.FLOAT16, TRTMode.INT8]


def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
"""The calibation data generator that yields calibration feed_dict to tensorrt."""
input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
1 change: 0 additions & 1 deletion modelopt/torch/_deploy/_runtime/trt_client.py
@@ -49,7 +49,6 @@ def deployment_table(self) -> DeploymentTable:
"bf16",
"fp8",
"int8",
"int8_iq",
"int4",
"stronglyTyped",
"best",
6 changes: 5 additions & 1 deletion tests/_test_utils/onnx_quantization/utils.py
@@ -20,7 +20,11 @@ def _assert_nodes_are_quantized(nodes):
for node in nodes:
for inp_idx, inp in enumerate(node.inputs):
if isinstance(inp, gs.Variable):
assert node.i(inp_idx).op == "DequantizeLinear", (
producer = node.i(inp_idx)
# Quantized path may include a Cast right after DQ
if producer and producer.op == "Cast":
producer = producer.i(0)
assert producer and producer.op == "DequantizeLinear", (
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
)
return True
11 changes: 8 additions & 3 deletions tests/examples/test_onnx_ptq.sh
@@ -161,18 +161,23 @@ for model_path in "${model_paths[@]}"; do
quant_mode="${all_modes[$i]}"
gpu_id=$((i % nvidia_gpu_count))

if [ "$quant_mode" == "fp16" ] || [ "$quant_mode" == "int8_iq" ]; then
if [ "$quant_mode" == "fp16" ]; then
eval_model_path=$model_dir/fp16/model.onnx
precision="fp16"
elif [ "$quant_mode" == "int8_iq" ]; then
eval_model_path=$model_dir/fp16/model.onnx
precision="best"
else
eval_model_path=$model_dir/$quant_mode/model.quant.onnx
precision="stronglyTyped"
fi

echo "Starting evaluation of $model_name for mode: $quant_mode on GPU $gpu_id"
if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--model_name="${timm_model_name[$model_name]}" \
--quantize_mode=$quant_mode \
--engine_precision=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
else
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
@@ -181,7 +186,7 @@ for model_path in "${model_paths[@]}"; do
--eval_data_size=$calib_size \
--batch_size $batch_size \
--model_name="${timm_model_name[$model_name]}" \
--quantize_mode=$quant_mode \
--engine_precision=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
fi
pids+=($!)
2 changes: 1 addition & 1 deletion tests/unit/onnx/test_qdq_rules_int8.py
@@ -41,7 +41,7 @@ def _assert_nodes_are_quantized(nodes):
def _assert_nodes_are_not_quantized(nodes):
for node in nodes:
for inp_idx, inp in enumerate(node.inputs):
if isinstance(inp, gs.Variable):
if isinstance(inp, gs.Variable) and inp.inputs:
assert node.i(inp_idx).op != "DequantizeLinear", (
f"Input '{inp.name}' of node '{node.name}' is quantized but shouldn't be!"
)