
Commit ea1f10f

i-riyad authored and benchislett committed
Making stronglyTyped default for modelopt evaluation (#287)
Signed-off-by: Riyad Islam <[email protected]>
1 parent bf76c4d commit ea1f10f

File tree

15 files changed: +45 -45 lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -5,10 +5,12 @@ Model Optimizer Changelog (Linux)
 ^^^^^^^^^^^^^^^^^

 **Deprecations**
+- Deprecated ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` to support strong typing. Use ``engine_precision`` instead.

 **Bug Fixes**

 **New Features**
+- ``high_precision_dtype`` defaults to fp16 in ONNX quantization, i.e. the quantized output model weights are now FP16 by default.

 0.35 (2025-09-04)
 ^^^^^^^^^^^^^^^^^

examples/onnx_ptq/README.md

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ The following evaluation requires the `val` directory of the [ImageNet dataset](
 python evaluate.py \
     --onnx_path=<path to classification model> \
     --imagenet_path=<path to the ImageNet dataset> \
-    --quantize_mode=<fp8|int8|int4> \
+    --engine_precision=stronglyTyped \
     --model_name=vit_base_patch16_224
 ```

@@ -165,7 +165,7 @@ If the input model is of type image classification, use the following script to
 python evaluate.py \
     --onnx_path=<path to the exported ONNX model> \
     --imagenet_path=<path to the ImageNet dataset> \
-    --quantize_mode=stronglyTyped \
+    --engine_precision=stronglyTyped \
     --model_name=vit_base_patch16_224
 ```

examples/onnx_ptq/evaluate.py

Lines changed: 5 additions & 12 deletions
@@ -48,29 +48,22 @@ def main():
     parser.add_argument(
         "--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
     )
-    # By default, TensorRT autotunes tensor types to generate the fastest engine. When you specify
-    # to TensorRT that a network is strongly typed, it infers a type for each intermediate and
-    # output tensor using the rules in the operator type specification. For networks quantized in
-    # INT4 or FP8 mode, stronglyTyped as the mode is recommended for TensorRT deployment. Though
-    # INT8 networks are generally compiled with int8 mode, certain INT8 ViT networks compiled with
-    # stronglyTyped precision have shown better performance.
     parser.add_argument(
-        "--quantize_mode",
+        "--engine_precision",
         type=str,
         default="stronglyTyped",
-        choices=["fp8", "fp16", "fp32", "int4", "int8", "int8_iq", "bf16", "best", "stronglyTyped"],
-        help="Quantization mode for the TensorRT engine. \
-            Supported options: fp8, fp16, fp32, int8, int8_iq(implicit quantization), bf16, best, stronglyTyped",
+        choices=["best", "fp16", "stronglyTyped"],
+        help="Precision mode for the TensorRT engine. \
+            stronglyTyped is recommended, all other modes have been deprecated in TensorRT",
     )
     parser.add_argument(
         "--results_path", type=str, default=None, help="Save the results to the specified path"
     )

     args = parser.parse_args()
-
     deployment = {
         "runtime": "TRT",
-        "precision": args.quantize_mode,
+        "precision": args.engine_precision,
     }

     # Create an ONNX bytes object with the specified path
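For reference, a minimal standalone sketch (not part of the commit; it only mirrors the argument definition shown above) of how the CLI behaves after this change:

```python
# Illustrative only: an argparse mirror of the new --engine_precision flag,
# showing the post-change default and the narrowed choices list.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--engine_precision",
    type=str,
    default="stronglyTyped",
    choices=["best", "fp16", "stronglyTyped"],
)

args = parser.parse_args([])  # no flag given
deployment = {"runtime": "TRT", "precision": args.engine_precision}
print(deployment)  # {'runtime': 'TRT', 'precision': 'stronglyTyped'}

# Former values such as "int8" or "fp8" are no longer valid choices;
# argparse rejects them with a usage error.
```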

examples/onnx_ptq/evaluation.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 deployment = {
     "runtime": "TRT",
     "accelerator": "GPU",
-    "precision": "fp32",
+    "precision": "stronglyTyped",
     "onnx_opset": "21",
 }

modelopt/onnx/quantization/__main__.py

Lines changed: 4 additions & 6 deletions
@@ -180,11 +180,11 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--high_precision_dtype",
         type=str,
-        default=None,
+        default="fp16",
         choices=["fp32", "fp16", "bf16"],
         help=(
-            "High precision data type, one of ['fp32', 'fp16', 'bf16']. For int8 quantization, the default value is "
-            "'fp32' and 'fp16' for other quantization modes."
+            "High precision data type of the output model. If the input model is of dtype fp32, "
+            "it will be converted to fp16 dtype by default."
        ),
     )
     argparser.add_argument(
@@ -262,8 +262,6 @@ def main():
     # Convert the NpzFile object to a Python dictionary
     calibration_data = {key: calibration_data[key] for key in calibration_data.files}

-    default_high_precision_dtype = "fp32" if args.quantize_mode == "int8" else "fp16"
-
     quantize(
         args.onnx_path,
         quantize_mode=args.quantize_mode,
@@ -284,7 +282,7 @@
         log_file=args.log_file,
         trt_plugins=args.trt_plugins,
         trt_plugins_precision=args.trt_plugins_precision,
-        high_precision_dtype=args.high_precision_dtype or default_high_precision_dtype,
+        high_precision_dtype=args.high_precision_dtype,
        mha_accumulation_dtype=args.mha_accumulation_dtype,
        disable_mha_qdq=args.disable_mha_qdq,
        dq_only=args.dq_only,
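With the fallback removed, int8 runs no longer silently produce an fp32 model. A small illustrative snippet reproducing the deleted logic for comparison (nothing here is new library code):

```python
# Old behavior (removed by this commit): when --high_precision_dtype was not given,
# int8 quantization fell back to an fp32 output model.
quantize_mode = "int8"
default_high_precision_dtype = "fp32" if quantize_mode == "int8" else "fp16"
print(default_high_precision_dtype)  # fp32  (old behavior)

# New behavior: the CLI default is simply "fp16" for every mode. Pass
# --high_precision_dtype fp32 explicitly to keep the previous int8 output dtype.
```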

modelopt/onnx/quantization/int8.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def quantize(
     use_external_data_format: bool = False,
     intermediate_generated_files: list[str] = [],
     trt_extra_plugin_lib_paths: list[str] | None = None,
-    high_precision_dtype: str = "fp32",
+    high_precision_dtype: str = "fp16",
     passes: list[str] = ["concat_elimination"],
     log_level: str = "INFO",
     calibrate_per_node: bool = False,

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 6 additions & 1 deletion
@@ -529,6 +529,11 @@ def _get_successive_consumers(
     quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
     if not quantized_node:
         raise ValueError(f"No consumer found for {dq_node.name}")
+    if quantized_node.op_type == "Cast":
+        next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
+        if not next_node:
+            raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
+        quantized_node = next_node

     return dq_node, quantized_node

@@ -1030,7 +1035,7 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:

     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
-    graph.node.clear()
+    del graph.node[:]
     graph.node.extend(new_nodes)

     def is_fp32_cast(node: onnx.NodeProto) -> bool:
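The added Cast hop matters because, with fp16 output models now the default, a DequantizeLinear can feed a Cast before the actual compute node. A toy sketch of the lookup (simplified consumer map and made-up node names; not the library's own code path):

```python
from collections import defaultdict

from onnx import TensorProto, helper

# Toy pattern: DequantizeLinear -> Cast -> Conv, as can appear in an fp16 model.
dq = helper.make_node("DequantizeLinear", ["w_q", "w_scale"], ["w_dq"], name="dq0")
cast = helper.make_node("Cast", ["w_dq"], ["w_cast"], to=TensorProto.FLOAT16, name="cast0")
conv = helper.make_node("Conv", ["x", "w_cast"], ["y"], name="conv0")

# Map each tensor name to the nodes that consume it.
tensor_consumers = defaultdict(list)
for node in (dq, cast, conv):
    for inp in node.input:
        tensor_consumers[inp].append(node)

quantized_node = tensor_consumers.get(dq.output[0], [None])[0]
if quantized_node is not None and quantized_node.op_type == "Cast":
    # Mirror of the patched lookup: step over the Cast to reach the real consumer.
    quantized_node = tensor_consumers.get(quantized_node.output[0], [None])[0]

print(quantized_node.name, quantized_node.op_type)  # conv0 Conv
```

The `graph.node.clear()` to `del graph.node[:]` swap in the second hunk is behavior-preserving; slice deletion is the long-standing idiom for emptying a protobuf repeated field.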

modelopt/onnx/quantization/quantize.py

Lines changed: 8 additions & 7 deletions
@@ -219,7 +219,7 @@ def quantize(
     log_file: str | None = None,
     trt_plugins: list[str] | None = None,
     trt_plugins_precision: list[str] | None = None,
-    high_precision_dtype: str | None = None,
+    high_precision_dtype: str = "fp16",
     mha_accumulation_dtype: str = "fp16",
     disable_mha_qdq: bool = False,
     dq_only: bool = True,
@@ -286,12 +286,13 @@ def quantize(
             Each item should have the format <op_type>:<precision>, where precision can be fp32 (default) or fp16.
             For example: op_type_1:fp16 op_type_2:fp32.
         high_precision_dtype:
-            High precision data type, one of ['fp32', 'fp16']. If high_precision_dtype == 'fp16', model's weight and
-            activation will be converted to fp16.
+            High precision data type of the output model. If high_precision_dtype is 'fp16' or 'bf16'
+            and the input model is of dtype fp32, model's weight and activation will be converted to
+            'fp16' or 'bf16'.
         mha_accumulation_dtype:
-            MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
-            If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
-            MHA's bmm1 and bmm2's input and output tensors.
+            MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default. If quantize_mode == 'fp8' and
+            mha_accumulation_dtype == 'fp32', Cast nodes will be added to MHA's bmm1 and bmm2's input
+            and output tensors.
         disable_mha_qdq:
             Don't add Q/DQ layers to MatMuls in MHA pattern.
         dq_only:
@@ -461,7 +462,7 @@ def quantize(
             use_external_data_format=use_external_data_format,
             intermediate_generated_files=intermediate_generated_files,
             trt_extra_plugin_lib_paths=trt_plugins,
-            high_precision_dtype=high_precision_dtype,  # type: ignore[arg-type]
+            high_precision_dtype=high_precision_dtype,
             mha_accumulation_dtype=mha_accumulation_dtype,
             passes=passes,
             log_level=log_level,
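A hedged example of the updated public API (the import path follows the package's documented usage, the model path is a placeholder, and only parameters visible in this diff are used):

```python
from modelopt.onnx.quantization import quantize

# high_precision_dtype now defaults to "fp16"; request "fp32" explicitly to keep
# the previous int8 output dtype.
quantize("model.onnx", quantize_mode="int8", high_precision_dtype="fp32")

# Per the docstring above: with fp8 quantization and fp32 MHA accumulation, Cast
# nodes are inserted around MHA's bmm1/bmm2 input and output tensors.
quantize("model.onnx", quantize_mode="fp8", mha_accumulation_dtype="fp32")
```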

modelopt/torch/_deploy/_runtime/tensorrt/constants.py

Lines changed: 0 additions & 2 deletions
@@ -86,7 +86,6 @@ class TRTMode:
     BFLOAT16 = "bf16"
     FLOAT8 = "fp8"
     INT8 = "int8"
-    INT8_IQ = "int8_iq"
     INT4 = "int4"
     STRONGLY_TYPED = "stronglyTyped"
     BEST = "best"
@@ -98,7 +97,6 @@ class TRTMode:
     TRTMode.BFLOAT16: ["--bf16"],
     TRTMode.FLOAT8: ["--fp16", "--fp8"],
     TRTMode.INT8: ["--fp16", "--int8"],
-    TRTMode.INT8_IQ: ["--int8"],
     TRTMode.INT4: ["--fp16", "--int4"],
     TRTMode.STRONGLY_TYPED: ["--stronglyTyped"],
     TRTMode.BEST: ["--best"],
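For context, a hypothetical sketch of how such a mode-to-flag table is consumed when composing a trtexec command (the dictionary and helper names are illustrative, not the module's real identifiers; only the entries are taken from the diff):

```python
# Hypothetical mirror of the flag table above.
TRT_MODE_FLAGS: dict[str, list[str]] = {
    "fp8": ["--fp16", "--fp8"],
    "int8": ["--fp16", "--int8"],
    "int4": ["--fp16", "--int4"],
    "stronglyTyped": ["--stronglyTyped"],
    "best": ["--best"],
}


def trtexec_cmd(onnx_path: str, mode: str) -> list[str]:
    # int8_iq is gone from the table, so implicit-quantization engines can no
    # longer be requested through this mapping.
    return ["trtexec", f"--onnx={onnx_path}", *TRT_MODE_FLAGS[mode]]


print(trtexec_cmd("model.onnx", "stronglyTyped"))
# ['trtexec', '--onnx=model.onnx', '--stronglyTyped']
```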

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py

Lines changed: 1 addition & 1 deletion
@@ -103,10 +103,10 @@ def _get_trtexec_params(
 def _is_low_bit_mode(trt_mode: str) -> bool:
     return trt_mode in [
         TRTMode.INT8,
-        TRTMode.INT8_IQ,
         TRTMode.INT4,
         TRTMode.FLOAT8,
         TRTMode.BEST,
+        TRTMode.STRONGLY_TYPED,
     ]