
Commit 12911e6

Making stronglyTyped default for modelopt evaluation
Signed-off-by: Riyad Islam <[email protected]>
1 parent 29b4cf2 commit 12911e6

12 files changed (+33, -41 lines)

examples/onnx_ptq/README.md

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ The following evaluation requires the `val` directory of the [ImageNet dataset](
 python evaluate.py \
     --onnx_path=<path to classification model> \
     --imagenet_path=<path to the ImageNet dataset> \
-    --quantize_mode=<fp8|int8|int4> \
+    --quantize_mode=stronglyTyped \
     --model_name=vit_base_patch16_224
 ```

examples/onnx_ptq/evaluate.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,16 @@ def main():
4848
parser.add_argument(
4949
"--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
5050
)
51-
# By default, TensorRT autotunes tensor types to generate the fastest engine. When you specify
52-
# to TensorRT that a network is strongly typed, it infers a type for each intermediate and
53-
# output tensor using the rules in the operator type specification. For networks quantized in
54-
# INT4 or FP8 mode, stronglyTyped as the mode is recommended for TensorRT deployment. Though
55-
# INT8 networks are generally compiled with int8 mode, certain INT8 ViT networks compiled with
56-
# stronglyTyped precision have shown better performance.
57-
parser.add_argument(
58-
"--quantize_mode",
59-
type=str,
60-
default="stronglyTyped",
61-
choices=["fp8", "fp16", "fp32", "int4", "int8", "int8_iq", "bf16", "best", "stronglyTyped"],
62-
help="Quantization mode for the TensorRT engine. \
63-
Supported options: fp8, fp16, fp32, int8, int8_iq(implicit quantization), bf16, best, stronglyTyped",
64-
)
6551
parser.add_argument(
6652
"--results_path", type=str, default=None, help="Save the results to the specified path"
6753
)
6854

6955
args = parser.parse_args()
7056

57+
# Note. stronglyTyped is recommended, all other modes have been deprecated in TensorRT
7158
deployment = {
7259
"runtime": "TRT",
73-
"precision": args.quantize_mode,
60+
"precision": "stronglyTyped",
7461
}
7562

7663
# Create an ONNX bytes object with the specified path
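
With this commit, `evaluate.py` no longer exposes a `--quantize_mode` flag; evaluation always deploys with strongly typed precision. A minimal sketch of the resulting deployment config, using only the values visible in the hunk above:

```python
# Sketch of the config evaluate.py now builds. "stronglyTyped" tells TensorRT to infer
# tensor types from the ONNX model instead of autotuning precisions per layer.
deployment = {
    "runtime": "TRT",
    "precision": "stronglyTyped",
}
```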

examples/onnx_ptq/evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
deployment = {
3030
"runtime": "TRT",
3131
"accelerator": "GPU",
32-
"precision": "fp32",
32+
"precision": "stronglyTyped",
3333
"onnx_opset": "21",
3434
}
3535

modelopt/onnx/quantization/__main__.py

Lines changed: 4 additions & 6 deletions
@@ -180,11 +180,11 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--high_precision_dtype",
         type=str,
-        default=None,
+        default="fp16",
         choices=["fp32", "fp16", "bf16"],
         help=(
-            "High precision data type, one of ['fp32', 'fp16', 'bf16']. For int8 quantization, the default value is "
-            "'fp32' and 'fp16' for other quantization modes."
+            "High precision data type of the output model. If the input model is of dtype fp32, "
+            "it will be converted to fp16 dtype by default."
         ),
     )
     argparser.add_argument(
@@ -262,8 +262,6 @@ def main():
     # Convert the NpzFile object to a Python dictionary
     calibration_data = {key: calibration_data[key] for key in calibration_data.files}
 
-    default_high_precision_dtype = "fp32" if args.quantize_mode == "int8" else "fp16"
-
     quantize(
         args.onnx_path,
         quantize_mode=args.quantize_mode,
@@ -284,7 +282,7 @@
         log_file=args.log_file,
         trt_plugins=args.trt_plugins,
         trt_plugins_precision=args.trt_plugins_precision,
-        high_precision_dtype=args.high_precision_dtype or default_high_precision_dtype,
+        high_precision_dtype=args.high_precision_dtype,
         mha_accumulation_dtype=args.mha_accumulation_dtype,
         disable_mha_qdq=args.disable_mha_qdq,
         dq_only=args.dq_only,
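
The mode-dependent fallback (fp32 for int8, fp16 otherwise) is removed; the argparse default now applies uniformly. A small sketch of the behavioral difference, derived only from the lines in this hunk:

```python
# Before this commit: the dtype depended on the quantize mode when --high_precision_dtype was omitted.
def old_high_precision_dtype(quantize_mode: str, cli_value: str | None) -> str:
    return cli_value or ("fp32" if quantize_mode == "int8" else "fp16")

# After: argparse supplies "fp16" unless the user overrides the flag explicitly.
def new_high_precision_dtype(cli_value: str = "fp16") -> str:
    return cli_value

assert old_high_precision_dtype("int8", None) == "fp32"  # old int8 behavior
assert new_high_precision_dtype() == "fp16"              # new default for every mode
```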

modelopt/onnx/quantization/int8.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def quantize(
     use_external_data_format: bool = False,
     intermediate_generated_files: list[str] = [],
     trt_extra_plugin_lib_paths: list[str] | None = None,
-    high_precision_dtype: str = "fp32",
+    high_precision_dtype: str = "fp16",
     passes: list[str] = ["concat_elimination"],
     log_level: str = "INFO",
     calibrate_per_node: bool = False,

modelopt/onnx/quantization/quantize.py

Lines changed: 8 additions & 7 deletions
@@ -219,7 +219,7 @@ def quantize(
     log_file: str | None = None,
     trt_plugins: list[str] | None = None,
     trt_plugins_precision: list[str] | None = None,
-    high_precision_dtype: str | None = None,
+    high_precision_dtype: str = "fp16",
     mha_accumulation_dtype: str = "fp16",
     disable_mha_qdq: bool = False,
     dq_only: bool = True,
@@ -286,12 +286,13 @@ def quantize(
             Each item should have the format <op_type>:<precision>, where precision can be fp32 (default) or fp16.
             For example: op_type_1:fp16 op_type_2:fp32.
         high_precision_dtype:
-            High precision data type, one of ['fp32', 'fp16']. If high_precision_dtype == 'fp16', model's weight and
-            activation will be converted to fp16.
+            High precision data type of the output model. If high_precision_dtype is 'fp16' or 'bf16'
+            and the input model is of dtype fp32, model's weight and activation will be converted to
+            'fp16' or 'bf16'.
         mha_accumulation_dtype:
-            MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
-            If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
-            MHA's bmm1 and bmm2's input and output tensors.
+            MHA accumulation dtype. One of ['fp32', 'fp16']. If quantize_mode == 'fp8' and
+            mha_accumulation_dtype == 'fp32', Cast nodes will be added to MHA's bmm1 and bmm2's input
+            and output tensors.
         disable_mha_qdq:
             Don't add Q/DQ layers to MatMuls in MHA pattern.
         dq_only:
@@ -461,7 +462,7 @@ def quantize(
         use_external_data_format=use_external_data_format,
         intermediate_generated_files=intermediate_generated_files,
         trt_extra_plugin_lib_paths=trt_plugins,
-        high_precision_dtype=high_precision_dtype,  # type: ignore[arg-type]
+        high_precision_dtype=high_precision_dtype,
         mha_accumulation_dtype=mha_accumulation_dtype,
         passes=passes,
         log_level=log_level,
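
For reference, a hedged usage sketch of calling this API with the new default. The import path is inferred from the file location (modelopt/onnx/quantization/quantize.py) and is not shown in this diff; verify it before use.

```python
# Assumed import path; the parameters below all appear in the hunks above.
from modelopt.onnx.quantization import quantize

quantize(
    "model.onnx",                  # hypothetical input model path
    quantize_mode="int8",
    high_precision_dtype="fp16",   # new default: an fp32 input model is emitted as fp16
    mha_accumulation_dtype="fp16",
)
```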

modelopt/torch/_deploy/_runtime/tensorrt/constants.py

Lines changed: 0 additions & 2 deletions
@@ -86,7 +86,6 @@ class TRTMode:
     BFLOAT16 = "bf16"
     FLOAT8 = "fp8"
     INT8 = "int8"
-    INT8_IQ = "int8_iq"
     INT4 = "int4"
     STRONGLY_TYPED = "stronglyTyped"
     BEST = "best"
@@ -98,7 +97,6 @@ class TRTMode:
     TRTMode.BFLOAT16: ["--bf16"],
     TRTMode.FLOAT8: ["--fp16", "--fp8"],
     TRTMode.INT8: ["--fp16", "--int8"],
-    TRTMode.INT8_IQ: ["--int8"],
     TRTMode.INT4: ["--fp16", "--int4"],
     TRTMode.STRONGLY_TYPED: ["--stronglyTyped"],
     TRTMode.BEST: ["--best"],
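
Put together, the remaining mode-to-trtexec-flag mapping looks roughly like the sketch below; the dictionary name is illustrative, since the real table's name is not visible in this hunk.

```python
# Illustrative name only; values are copied from the remaining entries in constants.py.
MODE_TO_TRTEXEC_FLAGS = {
    "bf16": ["--bf16"],
    "fp8": ["--fp16", "--fp8"],
    "int8": ["--fp16", "--int8"],
    "int4": ["--fp16", "--int4"],
    "stronglyTyped": ["--stronglyTyped"],
    "best": ["--best"],
}

print(MODE_TO_TRTEXEC_FLAGS["stronglyTyped"])  # ['--stronglyTyped']
```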

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py

Lines changed: 1 addition & 1 deletion
@@ -103,10 +103,10 @@ def _get_trtexec_params(
 def _is_low_bit_mode(trt_mode: str) -> bool:
     return trt_mode in [
         TRTMode.INT8,
-        TRTMode.INT8_IQ,
         TRTMode.INT4,
         TRTMode.FLOAT8,
         TRTMode.BEST,
+        TRTMode.STRONGLY_TYPED,
     ]
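
A self-contained sketch of the revised check, using the string values that TRTMode maps to in constants.py:

```python
# stronglyTyped now counts as a low-bit mode; int8_iq is no longer a recognized mode.
def _is_low_bit_mode(trt_mode: str) -> bool:
    return trt_mode in ["int8", "int4", "fp8", "best", "stronglyTyped"]

assert _is_low_bit_mode("stronglyTyped")
assert not _is_low_bit_mode("fp16")
```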

modelopt/torch/_deploy/_runtime/trt_client.py

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ def deployment_table(self) -> DeploymentTable:
             "bf16",
             "fp8",
             "int8",
-            "int8_iq",
             "int4",
             "stronglyTyped",
             "best",

tests/_test_utils/onnx_quantization/utils.py

Lines changed: 6 additions & 2 deletions
@@ -20,7 +20,11 @@ def _assert_nodes_are_quantized(nodes):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable):
-                assert node.i(inp_idx).op == "DequantizeLinear", (
-                    f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
+                qnode = node
+                # After quantization, the quantized node can be casted
+                if qnode.i(inp_idx).op == "Cast":
+                    qnode = qnode.i(inp_idx)
+                assert qnode.i(inp_idx).op == "DequantizeLinear", (
+                    f"Input '{inp.name}' of node '{qnode.name}' is not quantized but should be!"
                 )
     return True
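
In short, the helper now tolerates a single Cast between a node's Variable input and its DequantizeLinear producer, which is consistent with the fp16 output default introduced elsewhere in this commit inserting Cast nodes after quantized outputs.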
