
Commit 043d2ce

Merge branch 'main' into rislam/qwen-fix
2 parents 0bd0218 + b913290

File tree: 23 files changed (+184, -89 lines)


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -5,10 +5,12 @@ Model Optimizer Changelog (Linux)
 ^^^^^^^^^^^^^^^^^

 **Deprecations**
+- Deprecated the ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` in favor of strong typing. Use ``engine_precision`` instead.

 **Bug Fixes**

 **New Features**
+- ``high_precision_dtype`` now defaults to fp16 in ONNX quantization, i.e. quantized output model weights are FP16 by default.

 0.35 (2025-09-04)
 ^^^^^^^^^^^^^^^^^

examples/onnx_ptq/README.md

Lines changed: 2 additions & 2 deletions
@@ -120,7 +120,7 @@ The following evaluation requires the `val` directory of the [ImageNet dataset](
 python evaluate.py \
     --onnx_path=<path to classification model> \
     --imagenet_path=<path to the ImageNet dataset> \
-    --quantize_mode=<fp8|int8|int4> \
+    --engine_precision=stronglyTyped \
     --model_name=vit_base_patch16_224
 ```

@@ -165,7 +165,7 @@ If the input model is of type image classification, use the following script to
 python evaluate.py \
     --onnx_path=<path to the exported ONNX model> \
     --imagenet_path=<path to the ImageNet dataset> \
-    --quantize_mode=stronglyTyped \
+    --engine_precision=stronglyTyped \
     --model_name=vit_base_patch16_224
 ```

examples/onnx_ptq/docker/Dockerfile

Lines changed: 5 additions & 4 deletions
@@ -12,10 +12,11 @@ RUN python -m pip install --upgrade pip \

 WORKDIR /workspace

-RUN pip install tensorrt==10.13.2.6 && \
-    export TRT_PATH=$(python -c "import tensorrt; import os; print(os.path.dirname(tensorrt.__file__))") && \
-    export LD_LIBRARY_PATH="$TRT_PATH/lib:/usr/include:${LD_LIBRARY_PATH}" && \
-    export PATH="$TRT_PATH/bin:${PATH}"
+RUN pip install tensorrt==10.13.2.6
+ENV TRT_PATH=/usr/local/lib/python3.12/dist-packages/tensorrt
+ENV CUDNN_LIB_DIR=/usr/local/lib/python3.12/dist-packages/nvidia/cudnn/lib
+ENV LD_LIBRARY_PATH="${CUDNN_LIB_DIR}:${TRT_PATH}/lib:/usr/include:${LD_LIBRARY_PATH}"
+ENV PATH="${TRT_PATH}/bin:${PATH}"

 # Copy application code and install requirements
 COPY modelopt modelopt/modelopt
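
A note on this change: `export` statements inside a `RUN` instruction only affect that single build layer, whereas `ENV` values persist into every later layer and into the running container. A minimal sanity-check sketch, not part of this commit, that one might run inside the built image to confirm the paths resolve (how the image is invoked is an assumption):

```python
# check_trt_env.py - hedged sanity check, not part of this commit.
# Run inside the built container to confirm the ENV-based paths are visible.
import os

import tensorrt as trt

print(trt.__version__)  # expected to match the pinned 10.13.2.6
print(os.environ.get("TRT_PATH"))  # /usr/local/lib/python3.12/dist-packages/tensorrt
print("cudnn/lib" in os.environ.get("LD_LIBRARY_PATH", ""))  # True if the cuDNN libs are on the path
```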

examples/onnx_ptq/evaluate.py

Lines changed: 5 additions & 12 deletions
@@ -48,29 +48,22 @@ def main():
     parser.add_argument(
         "--eval_data_size", type=int, default=None, help="Number of examples to evaluate"
     )
-    # By default, TensorRT autotunes tensor types to generate the fastest engine. When you specify
-    # to TensorRT that a network is strongly typed, it infers a type for each intermediate and
-    # output tensor using the rules in the operator type specification. For networks quantized in
-    # INT4 or FP8 mode, stronglyTyped as the mode is recommended for TensorRT deployment. Though
-    # INT8 networks are generally compiled with int8 mode, certain INT8 ViT networks compiled with
-    # stronglyTyped precision have shown better performance.
     parser.add_argument(
-        "--quantize_mode",
+        "--engine_precision",
         type=str,
         default="stronglyTyped",
-        choices=["fp8", "fp16", "fp32", "int4", "int8", "int8_iq", "bf16", "best", "stronglyTyped"],
-        help="Quantization mode for the TensorRT engine. \
-            Supported options: fp8, fp16, fp32, int8, int8_iq(implicit quantization), bf16, best, stronglyTyped",
+        choices=["best", "fp16", "stronglyTyped"],
+        help="Precision mode for the TensorRT engine. \
+            stronglyTyped is recommended, all other modes have been deprecated in TensorRT",
     )
     parser.add_argument(
         "--results_path", type=str, default=None, help="Save the results to the specified path"
     )

     args = parser.parse_args()
-
     deployment = {
         "runtime": "TRT",
-        "precision": args.quantize_mode,
+        "precision": args.engine_precision,
     }

     # Create an ONNX bytes object with the specified path

examples/onnx_ptq/evaluation.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 deployment = {
     "runtime": "TRT",
     "accelerator": "GPU",
-    "precision": "fp32",
+    "precision": "stronglyTyped",
     "onnx_opset": "21",
 }

examples/onnx_ptq/torch_quant_to_onnx.py

Lines changed: 9 additions & 3 deletions
@@ -83,12 +83,12 @@ def forward_loop(model):
     return quantized_model


-def get_model_input_shape(model_name):
+def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (1, *tuple(input_size))  # Add batch dimension
+    return (batch_size, *tuple(input_size))  # Add batch dimension


 def main():

@@ -119,11 +119,17 @@ def main():
         default=512,
         help="Number of images to use in calibration [1-512]",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size for calibration and ONNX model export.",
+    )

     args = parser.parse_args()

     # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name)
+    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)

     # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
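
For context, the updated helper simply threads the new `--batch_size` argument into the leading dimension of the timm-reported input size. A standalone sketch of the same logic, using the same timm calls as the diff (the model name is only an example, and `pretrained=False` is used here to avoid a weight download):

```python
# Hedged sketch mirroring get_model_input_shape() above; not part of the commit.
import timm


def input_shape_for(model_name: str, batch_size: int) -> tuple:
    model = timm.create_model(model_name, pretrained=False, num_classes=1000)
    data_config = timm.data.resolve_model_data_config(model)
    return (batch_size, *tuple(data_config["input_size"]))  # batch dim comes from the CLI now


print(input_shape_for("vit_base_patch16_224", 4))  # expected: (4, 3, 224, 224)
```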
Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 datasets>=2.14.5
+onnx==1.18.0
 torch==2.6.0
 transformers==4.49.0

modelopt/onnx/quantization/__main__.py

Lines changed: 4 additions & 6 deletions
@@ -180,11 +180,11 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--high_precision_dtype",
         type=str,
-        default=None,
+        default="fp16",
         choices=["fp32", "fp16", "bf16"],
         help=(
-            "High precision data type, one of ['fp32', 'fp16', 'bf16']. For int8 quantization, the default value is "
-            "'fp32' and 'fp16' for other quantization modes."
+            "High precision data type of the output model. If the input model is of dtype fp32, "
+            "it will be converted to fp16 dtype by default."
         ),
     )
     argparser.add_argument(

@@ -262,8 +262,6 @@ def main():
     # Convert the NpzFile object to a Python dictionary
     calibration_data = {key: calibration_data[key] for key in calibration_data.files}

-    default_high_precision_dtype = "fp32" if args.quantize_mode == "int8" else "fp16"
-
     quantize(
         args.onnx_path,
         quantize_mode=args.quantize_mode,

@@ -284,7 +282,7 @@ def main():
         log_file=args.log_file,
         trt_plugins=args.trt_plugins,
         trt_plugins_precision=args.trt_plugins_precision,
-        high_precision_dtype=args.high_precision_dtype or default_high_precision_dtype,
+        high_precision_dtype=args.high_precision_dtype,
         mha_accumulation_dtype=args.mha_accumulation_dtype,
         disable_mha_qdq=args.disable_mha_qdq,
         dq_only=args.dq_only,
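
The practical effect of this change: `--high_precision_dtype` now defaults to fp16 for every quantize mode, where int8 previously fell back to fp32. A hedged sketch of the Python-side equivalent, assuming the `quantize` called in `__main__.py` is the `modelopt.onnx.quantization.quantize` entry point (the path below is a placeholder):

```python
# Hedged sketch, not part of the commit: keep the old int8 behavior by passing
# high_precision_dtype explicitly, since the implicit fp32 fallback is gone.
from modelopt.onnx.quantization import quantize  # assumed import path for the CLI's entry point

quantize(
    "model.onnx",                 # placeholder path
    quantize_mode="int8",
    high_precision_dtype="fp32",  # new default is "fp16"; set "fp32" to match pre-change output
)
```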

modelopt/onnx/quantization/int8.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def quantize(
     use_external_data_format: bool = False,
     intermediate_generated_files: list[str] = [],
     trt_extra_plugin_lib_paths: list[str] | None = None,
-    high_precision_dtype: str = "fp32",
+    high_precision_dtype: str = "fp16",
     passes: list[str] = ["concat_elimination"],
     log_level: str = "INFO",
     calibrate_per_node: bool = False,

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 71 additions & 14 deletions
@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn

 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger

@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,

@@ -529,6 +529,11 @@ def _get_successive_consumers(
     quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
     if not quantized_node:
         raise ValueError(f"No consumer found for {dq_node.name}")
+    if quantized_node.op_type == "Cast":
+        next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
+        if not next_node:
+            raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
+        quantized_node = next_node

     return dq_node, quantized_node
@@ -592,7 +597,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)

     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +612,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8


 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
@@ -685,7 +699,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

     # Create and update new weight tensor
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         new_weight = _create_fp8_tensor(scaled, weight_name)
         logger.debug(f"Converted {weight_name} to FP8")
     else:

@@ -925,6 +939,10 @@ def quantize_weights_to_int4(
     assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
     reshape_node_output = reshape_node.output[0]

+    # Remove constant node from reshape node
+    shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+    nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
     # Get the shape of the output of the reshape node
     reshape_output_value_info = value_info_map.get(reshape_node_output)
     if reshape_output_value_info is not None:

@@ -942,12 +960,17 @@
     scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
     scale = scale.reshape(scale_shape)
     reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-    # reshape_node.input = []
     assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"

+    # Remove unnecessary Cast node
+    cast_node = reshape_child_nodes[0]
+    assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+    nodes_to_remove.append(cast_node.name)
+    cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
     # Transpose weights and scales if present
-    if reshape_child_nodes[0].op_type == "Transpose":
-        transpose_node = reshape_child_nodes[0]
+    if cast_child_nodes[0].op_type == "Transpose":
+        transpose_node = cast_child_nodes[0]
         nodes_to_remove.append(transpose_node.name)
         assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
         perm = None

@@ -964,7 +987,7 @@
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -995,9 +1018,24 @@
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")

+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
-    graph.node.clear()
+    del graph.node[:]
     graph.node.extend(new_nodes)

     def is_fp32_cast(node: onnx.NodeProto) -> bool:
@@ -1009,7 +1047,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1104,7 +1142,13 @@ def quantize_weights_to_mxfp8(
     # Expand block array so that it can be broadcasted with weight
     se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
     scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-    weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+    weights_e4m3 = onnx.helper.make_tensor(
+        name=weight_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*scaled_weight.shape],
+        vals=_cast_fp8(scaled_weight).tobytes(),
+        raw=True,
+    )
     initializer_map[weight_name].CopyFrom(weights_e4m3)
     logger.debug(f"Converted {weight_name} to MXFP8")

@@ -1186,11 +1230,24 @@ def _add_input_value_info(graph, tensor_proto):
     sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

     # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
     sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
         sw_f32_per_tensor, sw_f32_per_tensor_name
     )
     sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )

     # Add ValueInfo for the initializers if not present
     _add_input_value_info(graph, w_f4_proto)
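
Both hunks switch from `onnx.numpy_helper.from_array`, which infers the data type from the numpy dtype and so cannot express FP8/FP4 payloads stored as plain uint8, to `onnx.helper.make_tensor` with raw bytes and an explicit `data_type`. A hedged, self-contained sketch of that pattern (the tensor name and byte values below are made up):

```python
# Hedged sketch of building an 8-bit-float initializer from raw bytes.
import numpy as np
import onnx
from onnx import helper

payload = np.array([0x3C, 0x40, 0x44, 0x48], dtype=np.uint8)  # pretend these are FP8E4M3 codes
fp8_init = helper.make_tensor(
    name="example_weight_fp8",  # hypothetical initializer name
    data_type=onnx.TensorProto.FLOAT8E4M3FN,
    dims=list(payload.shape),   # dims describe logical elements, one byte per element for FP8
    vals=payload.tobytes(),
    raw=True,
)
print(fp8_init.data_type == onnx.TensorProto.FLOAT8E4M3FN, len(fp8_init.raw_data))  # True 4
```

For FLOAT4E2M1 the dims describe logical FP4 elements while the raw payload holds two elements per byte, which is why the hunk above multiplies `w_f4.shape[0]` by 2.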
