
Commit 6abded4

[5455919] Fix Q/DQ/Cast placement in 'FP32 required' custom ops (NVIDIA#554)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix incorrect quantization of custom ops when some input tensors are required to be in INT8 and some in FP32.

| Before fix | After fix |
|------------|-----------|
| <img width="841" height="623" alt="snap_custom_op_quant_incorrect" src="https://github.com/user-attachments/assets/88e4d460-fbae-4bcb-86c8-139d23ce04c8" /> | <img width="786" height="286" alt="snap_custom_op_quant_correct" src="https://github.com/user-attachments/assets/475079c2-a565-4f0d-b167-6d801ab83dfc" /> |

## Usage

```sh
$ python -m modelopt.onnx.quantization --onnx_path=$MODEL_PATH.onnx \
    --trt_plugins $PLUGIN_PATH.so \
    --trt_plugins_precision $CUSTOM_OP_NAME:$PRECISION
```

## Testing

### 1. BEVFormer model

- Follow step 1 in the [README](https://github.com/NVIDIA/DL4AGX/tree/master/AV-Solutions/bevformer-int8-eq#1-export-model-to-onnx-and-compile-plugins).
- In the quantization step, run:

```sh
$ python -m modelopt.onnx.quantization --onnx_path=/mnt/models/bevformer_tiny_epoch_24_cp2_op13.onnx \
    --trt_plugins=$PLUGIN_PATH \
    --trt_plugins_precision MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8] \
    --high_precision_dtype fp16
```

> See the table in "Overview" for the expected graph structure.

### 2. 5455919 model

Validated the model in bug 5455919.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

- NVIDIA/pull/363: Feature expansion.
- NVIDIA/pull/524: The graph cleanup is actually needed after Q/DQ trimming around custom ops; the cleanup lines were moved into that function.

---------

Signed-off-by: gcunhase <[email protected]>
1 parent e20d218 commit 6abded4

File tree

7 files changed: +77 −41 lines

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -20,12 +20,14 @@ Model Optimizer Changelog (Linux)
 **Bug Fixes**
 
 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
+- Fix Q/DQ/Cast node placements in 'FP32 required' tensors in custom ops in the ONNX quantization workflow.
 
 **New Features**
 
 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 
 
 0.39 (2025-11-11)

modelopt/onnx/autocast/convert.py

Lines changed: 3 additions & 1 deletion
@@ -194,6 +194,7 @@ def convert_to_f16(
     low_precision_type: str = "fp16",
     keep_io_types: bool = True,
     op_block_list: list[str] = [],
+    tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
@@ -204,8 +205,8 @@ def convert_to_f16(
         model: ONNX model to convert.
         low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
         keep_io_types: Whether to preserve input/output types.
-        disable_shape_infer: Whether to disable shape inference.
         op_block_list: List of operation types that should remain in FP32.
+        tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
     """
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
@@ -235,6 +236,7 @@ def convert_to_f16(
         keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
         custom_ops=sanitizer.custom_ops,
+        tensor_block_dict=tensor_block_dict,
     )
     high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
     low_precision_nodes = [
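
To make the new argument concrete, here is a minimal sketch of a `convert_to_f16` call that keeps one input of a custom op in FP32. The op name, input index, plugin path, and file names are illustrative assumptions; only the parameter names and the `{"inp": [...]}` dictionary shape come from the signature above.

```python
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

# Hypothetical example: input index 2 of "MyCustomOpTRT" must stay in FP32
# while the rest of the graph is converted to FP16.
model = onnx.load("model.onnx")
tensor_block_dict = {"MyCustomOpTRT": {"inp": [2]}}

converted = convert_to_f16(
    model,
    low_precision_type="fp16",
    keep_io_types=True,
    op_block_list=[],
    tensor_block_dict=tensor_block_dict,
    trt_plugins=["./libmy_plugin.so"],
)
onnx.save(converted, "model_fp16.onnx")
```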

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 38 additions & 16 deletions
@@ -99,6 +99,7 @@ def __init__(
         min_opset: int = 13,
         max_ir_version: int | None = None,
         trt_plugins: list[str] | None = [],
+        tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     ) -> None:
         """Initialize PrecisionConverter.
 
@@ -112,6 +113,10 @@ def __init__(
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
                 cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum opset for conversion.
+            max_ir_version: Max IR version for conversion.
+            trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
+            tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
         """
         self.model = deepcopy(model)
         self.value_info_map = value_info_map
@@ -148,6 +153,9 @@ def __init__(
             )
         )
 
+        # Custom mapping of op types to indices of inputs that should not be converted to low precision
+        self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict)
+
     def convert(
         self,
         high_precision_nodes: list[str],
@@ -211,7 +219,8 @@ def convert(
                 # For the low precision nodes that take a FP32 input, we don't exclude it from
                 # casting up so that the input can be converted to FP32 as expected.
                 exclude_consumers = list(
-                    set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
+                    set(low_precision_nodes)
+                    - {n.name for n in fp32_input_to_low_precision_node[tensor_name]}
                 )
                 self._add_cast(
                     tensor_name,
@@ -467,12 +476,14 @@ def _filter_unsupported_op_types(
         return high_precision_nodes, low_precision_nodes
 
     def _get_tensors_to_cast(
-        self, low_precision_nodes: list[str]
-    ) -> tuple[list[str], list[str], dict[str, onnx.NodeProto]]:
+        self,
+        low_precision_nodes: list[str],
+        high_precision_tensors: dict[str, dict[str, list[int]]] = {},
+    ) -> tuple[list[str], list[str], dict[str, list[onnx.NodeProto]]]:
         cast_to_fp16 = []  # Tensors to cast down to FP16
         cast_to_fp32 = []  # Tensors to cast up to FP32
         # Keep track of the low precision nodes that take a FP32 input.
-        fp32_input_to_low_precision_node = {}
+        fp32_input_to_low_precision_node = defaultdict(list)
 
         # Get tensors for FP16 nodes
         for node in self.model.graph.node:
@@ -481,7 +492,7 @@ def _get_tensors_to_cast(
             for input in node.input:
                 if self._should_skip_low_precision_input_conversion(node, input):
                     cast_to_fp32.append(input)
-                    fp32_input_to_low_precision_node[input] = node
+                    fp32_input_to_low_precision_node[input].append(node)
                 else:
                     cast_to_fp16.append(input)
 
@@ -536,7 +547,7 @@ def _convert_initializers(
             low_precision_nodes: List of node names that should use low precision initializers.
             high_precision_nodes: List of node names that should use high precision initializers.
         """
-        # 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them.
+        # 1. Compute a mapping from initializers to high precision nodes & low precision nodes that use them.
         low_precision_nodes_set: set[str] = set(low_precision_nodes)
         high_precision_nodes_set: set[str] = set(high_precision_nodes)
         initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict(
@@ -888,7 +899,7 @@ def _add_cast(
         )
 
         if tensor_to_consumers is None:
-            utils.get_consumer_nodes(self.model, tensor_name)
+            consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
         else:
             consumer_nodes = tensor_to_consumers.get(tensor_name, [])
         consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
@@ -1272,13 +1283,9 @@ def _sanitize_model(self):
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model
 
-    def _should_skip_low_precision_input_conversion(
-        self, node: onnx.NodeProto, input_name: str
-    ) -> bool:
-        """Check if the input should be skipped for low precision conversion.
-
-        This is used for nodes that have inputs that MUST remain in FP32.
-        """
+    def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
+        """Create mapping of op types to indices of inputs that should not be converted to low precision."""
+        skip_inputs_map = {}
         match self.low_precision_type.str_short:
             case "fp16":
                 skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
@@ -1287,12 +1294,27 @@ def _should_skip_low_precision_input_conversion(
             case _:
                 raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
 
-        if node.op_type in skip_inputs_map:
+        # Update mapping with user-defined information
+        for op, tensor_map in tensor_block_dict.items():
+            high_precision_tensor = tensor_map.get("inp", [])
+            if high_precision_tensor:
+                skip_inputs_map.update({op: set(high_precision_tensor)})
+
+        return skip_inputs_map
+
+    def _should_skip_low_precision_input_conversion(
+        self, node: onnx.NodeProto, input_name: str
+    ) -> bool:
+        """Check if the input should be skipped for low precision conversion.
+
+        This is used for nodes that have inputs that MUST remain in FP32.
+        """
+        if node.op_type in self.skip_inputs_map:
             # Figure out the index of the input in the node input
             inputs_lst = list(node.input)
             if input_name not in inputs_lst:
                 raise ValueError(f"Input {input_name} not found in node {node.name}.")
             input_index = inputs_lst.index(input_name)
             # Check if we should skip this input for low precision conversion
-            return input_index in skip_inputs_map[node.op_type]
+            return input_index in self.skip_inputs_map[node.op_type]
         return False
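
To illustrate how `tensor_block_dict` interacts with the built-in skip table, below is a small standalone sketch of the merge performed in `_create_skip_inputs_mapping`. The base table contents and the custom op name are placeholders, not the actual `SKIP_LOW_PRECISION_MAPPING_FP16` values.

```python
# Standalone sketch of the merge logic in _create_skip_inputs_mapping.
base_skip_inputs_map = {"Resize": {1, 2}}  # placeholder; the real FP16 table lives in the module

# User request (e.g. parsed from --trt_plugins_precision): input index 2 must stay FP32.
tensor_block_dict = {"MyCustomOpTRT": {"inp": [2]}}

skip_inputs_map = dict(base_skip_inputs_map)
for op, tensor_map in tensor_block_dict.items():
    high_precision_inputs = tensor_map.get("inp", [])
    if high_precision_inputs:
        skip_inputs_map[op] = set(high_precision_inputs)

print(skip_inputs_map)  # {'Resize': {1, 2}, 'MyCustomOpTRT': {2}}
```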

modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,7 @@ def quantize(
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
     op_types_to_exclude_fp16: list[str] | None = None,
+    custom_ops_to_cast_fp32: dict | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -324,6 +325,7 @@ def quantize(
             onnx_model,
             keep_io_types=not direct_io_types,
             op_block_list=op_types_to_exclude_fp16 or [],
+            tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions
@@ -120,6 +120,7 @@ def quantize(
     op_types_to_quantize: list[str] | None = None,
     op_types_to_exclude: list[str] | None = None,
     op_types_to_exclude_fp16: list[str] | None = None,
+    custom_ops_to_cast_fp32: dict | None = None,
     nodes_to_quantize: list[str] | None = None,
     nodes_to_exclude: list[str] | None = None,
     use_external_data_format: bool = False,
@@ -285,6 +286,7 @@ def quantize(
             onnx_model,
             keep_io_types=not direct_io_types,
             op_block_list=op_types_to_exclude_fp16 or [],
+            tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
         )

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 29 additions & 14 deletions
@@ -872,22 +872,32 @@ def remove_input_dq_and_output_q(
         )
 
         # Only remove DQs from the inputs of custom ops
-        if consumers[0].op_type not in quantizable_custom_ops:
+        has_cast = consumers[0].op_type == "Cast"
+        consumers_2 = tensor_consumers[consumers[0].output[0]] if has_cast else consumers
+        if consumers_2[0].op_type not in quantizable_custom_ops:
             continue
 
-        # Rewire graph to connect Q with the node after DQ (skip DQ)
-        for consumer in consumers:
-            for cons_idx, cons_inp in enumerate(consumer.input):
-                if cons_inp == node.output[0]:
-                    # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
-                    if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
-                        consumer.input[cons_idx] = q_node.output[0]
-                    else:
-                        q_node_prev = tensor_producers.get(q_node.input[0], None)
-                        consumer.input[cons_idx] = (
-                            q_node_prev.output[0] if q_node_prev else q_node.input[0]
-                        )
-                    break
+        if has_cast:
+            # Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ
+            # and the custom op. Keep the Cast node and delete both Q/DQ nodes.
+            q_node_prev = tensor_producers.get(q_node.input[0], None)
+            consumers[0].input[0] = (
+                q_node_prev.output[0] if q_node_prev else q_node.input[0]
+            )
+        else:
+            # Rewire graph to connect Q with the node after DQ (skip DQ)
+            for consumer in consumers:
+                for cons_idx, cons_inp in enumerate(consumer.input):
+                    if cons_inp == node.output[0]:
+                        # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
+                        if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
+                            consumer.input[cons_idx] = q_node.output[0]
+                        else:
+                            q_node_prev = tensor_producers.get(q_node.input[0], None)
+                            consumer.input[cons_idx] = (
+                                q_node_prev.output[0] if q_node_prev else q_node.input[0]
+                            )
+                        break
 
         # Track DequantizeLinear node indices for cleanup
         dq_indices.append(node_idx)
@@ -944,6 +954,11 @@ def remove_input_dq_and_output_q(
         f" {len(dq_indices)} DQ node{'' if len(dq_indices) == 1 else 's'}"
     )
 
+    # Cleanup graph to remove any dangling Q/DQ nodes
+    graph = gs.import_onnx(onnx_model)
+    graph.cleanup()
+    onnx_model = gs.export_onnx(graph)
+
     # TODO: remove manual ir_version change once ORT supports ir_version 11
     onnx_model.ir_version = 10
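
As a rough illustration of the new Cast branch: when the pattern is producer -> Q -> DQ -> Cast -> custom op, the tensor is treated as "FP32 required", so both Q and DQ are removed and the Cast is rewired to the original producer. Any Q/DQ nodes left dangling are then dropped by the onnx-graphsurgeon cleanup added above; a minimal standalone version of that cleanup is sketched below (file paths are placeholders).

```python
# Before:  producer -> QuantizeLinear -> DequantizeLinear -> Cast(to FP32) -> CustomOp
# After:   producer -> Cast(to FP32) -> CustomOp   (Q/DQ removed, Cast kept)
import onnx
import onnx_graphsurgeon as gs

onnx_model = onnx.load("quantized_model.onnx")  # placeholder path
graph = gs.import_onnx(onnx_model)
graph.cleanup()  # removes nodes that no longer contribute to any graph output
onnx_model = gs.export_onnx(graph)
onnx.save(onnx_model, "quantized_model_clean.onnx")
```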

modelopt/onnx/quantization/quantize.py

Lines changed: 1 addition & 10 deletions
@@ -430,16 +430,6 @@ def quantize(
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
-    # Update list with op types to exclude from FP16/BF16 conversion
-    op_types_to_exclude_fp16 = list(
-        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
-    )
-    if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
-        logger.warning(
-            "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
-            "Since the model won't be converted to a lower precision, this flag is void."
-        )
-
     # Use random scales if calibration data is not supplied
     if calibration_data is None:
         calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
@@ -485,6 +475,7 @@ def quantize(
         op_types_to_quantize=op_types_to_quantize,
         op_types_to_exclude=op_types_to_exclude,
         op_types_to_exclude_fp16=op_types_to_exclude_fp16,
+        custom_ops_to_cast_fp32=custom_ops_to_cast_fp32,
        nodes_to_quantize=nodes_to_quantize,
         nodes_to_exclude=nodes_to_exclude,
         use_external_data_format=use_external_data_format,
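
For context, `custom_ops_to_cast_fp32` has the same shape as `tensor_block_dict` (op type mapped to I/O index lists) and is now forwarded as-is to the INT8/FP8 paths instead of being folded into `op_types_to_exclude_fp16`, so only the listed inputs stay in FP32 rather than the whole op. A plausible, purely illustrative parsed form of the `--trt_plugins_precision MultiScaleDeformableAttnTRT:[int8,int32,fp32,int8,int8]:[int8]` example above is shown below; the exact parsing rules are not part of this commit, so treat the indices as an assumption.

```python
# Hypothetical parsed form: the fp32 entry at input index 2 is recorded so that
# input keeps its Cast to FP32 and is excluded from INT8 quantization.
custom_ops_to_cast_fp32 = {"MultiScaleDeformableAttnTRT": {"inp": [2]}}
```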
