
Commit 957ce07

vishalpandya1990 authored and kevalmorabia97 committed
Fix DQ1 output type error in DQ1->DQ2 for FP4 weights in NVFP4 model (#513)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:**
- In post-processing after NVFP4 PTQ and ONNX export, we convert the FP4 QDQ pattern into DQ1->DQ2 for the FP4 weights of the MatMuls. The output of DQ1 is of the original weight type (FP16 for an FP16 base model), but its scale is in FP32. There is a cast-to-FP16 after DQ2.
- In the above setting, with FP16 base-model weights, DQ1 has x_scale in FP32 but its output type set to FP16. This hybrid-precision mode is not allowed up to opset 21, so it leads to an error when the model is run with ONNX Runtime.
- Note that such a hybrid-precision mode is allowed in opset 23+, but it is not yet fully supported by ONNX Runtime EPs, and we also want to keep supporting opset < 23 going forward.
- So this change sets the output of DQ1 to FP32, since its scale is in FP32. There is already a cast-to-FP16 after DQ2 (before the Gemm). A minimal sketch of the resulting pattern follows this description.

## Testing
- Checked with the trtexec binary and the onnxruntime-trt-rtx EP, using the sd3.5-medium model, on Windows with an RTX 5090.

## Before your PR is "*Ready for review*"
- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: vipandya <[email protected]>
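For context, here is a minimal sketch (not the modelopt code path) of the DQ1 -> DQ2 -> Cast weight pattern described above after the fix, built with `onnx.helper`: both DQ outputs are typed FP32 to match the FP32 scale, and a Cast restores FP16 before the Gemm. The tensor names and the block size are illustrative placeholders, not the exact values modelopt emits.

```python
# Sketch of the post-fix FP4 weight path: DQ1 -> DQ2 -> Cast (names are illustrative).
from onnx import TensorProto, helper

# DQ1: (FP8 per-block weight scales, FP32 per-tensor scale) -> FP32 per-block scales.
dq1 = helper.make_node(
    "DequantizeLinear",
    inputs=["w_scale_f8_per_block", "w_scale_f32_per_tensor"],
    outputs=["w_scale_f32"],
    name="w_DequantizeLinear",
)

# DQ2: (FP4 weight, FP32 per-block scales) -> FP32 weight, block-quantized on the last axis.
dq2 = helper.make_node(
    "DequantizeLinear",
    inputs=["w_f4", "w_scale_f32"],
    outputs=["w_f32"],
    name="w_DequantizeLinear_1",
    axis=-1,
    block_size=16,  # NVFP4-style block size; illustrative here
)

# Cast back to the base-model precision (FP16) before the Gemm/MatMul.
cast = helper.make_node("Cast", inputs=["w_f32"], outputs=["w_f16"], to=TensorProto.FLOAT16)

# Declare the intermediate tensors as FP32: up to opset 21, a DequantizeLinear whose
# x_scale is FP32 must also produce FP32, so typing these FP16 is what ORT rejects.
value_infos = [
    helper.make_value_info("w_scale_f32", helper.make_tensor_type_proto(TensorProto.FLOAT, None)),
    helper.make_value_info("w_f32", helper.make_tensor_type_proto(TensorProto.FLOAT, None)),
]
```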
1 parent cab5e0e commit 957ce07

File tree

1 file changed: +14 −17 lines


modelopt/onnx/quantization/qdq_utils.py

Lines changed: 14 additions & 17 deletions
@@ -1317,7 +1317,6 @@ def replace_fp4qdq_with_2dq(
     w_f4: np.ndarray,
     sw_f32_per_tensor: np.ndarray,
     sw_f8_per_block: np.ndarray,
-    precision_dtype: str,
     block_size: int,
 ):
     """Replaces the given node in the ONNX graph with a subgraph consisting of two DequantizeLinear nodes.
@@ -1331,7 +1330,6 @@ def replace_fp4qdq_with_2dq(
         w_f4: NumPy array for w_f4.
         sw_f32_per_tensor: NumPy array for sw_f32_per_tensor.
         sw_f8_per_block: NumPy array for sw_f8_per_block.
-        precision_dtype: The precision of the weights.
         block_size: Block size used in block quantization.
     """

@@ -1391,39 +1389,39 @@ def _add_input_value_info(graph, tensor_proto):
     _add_initializer(sw_f32_per_tensor_proto)
     _add_initializer(sw_f8_per_block_proto)

-    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f16
-    sw_f16_name = weight_name + "_f16_scale"
+    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f32
+    sw_f32_name = weight_name + "_f32_scale"
     dequant1 = onnx.helper.make_node(
         "DequantizeLinear",
         inputs=[sw_f8_per_block_proto.name, sw_f32_per_tensor_proto.name],
-        outputs=[sw_f16_name],
+        outputs=[sw_f32_name],
         name=weight_name + "_DequantizeLinear",
     )

-    # Create DequantizeLinear_2 node: (w_f4, sw_f16) -> w_16
-    w16_name = node.output[0]
+    # Create DequantizeLinear_2 node: (w_f4, sw_f32) -> w_32
+    w32_name = node.output[0]
     dequant2 = onnx.helper.make_node(
         "DequantizeLinear",
-        inputs=[w_f4_proto.name, sw_f16_name],
-        outputs=[w16_name],
+        inputs=[w_f4_proto.name, sw_f32_name],
+        outputs=[w32_name],
         name=weight_name + "_DequantizeLinear_1",
         axis=-1,
         block_size=block_size,
     )

-    # Add value_info for sw_f16
+    # Add value_info for sw_f32
     # Assuming sw_f16 has the same shape as sw_f8_per_block
-    sw_f16_type_proto = onnx.helper.make_tensor_type_proto(
-        elem_type=onnx_dtype_map[precision_dtype], shape=sw_f8_per_block.shape
+    sw_f32_type_proto = onnx.helper.make_tensor_type_proto(
+        elem_type=onnx_dtype_map["Float"], shape=sw_f8_per_block.shape
     )
-    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f16_name, type_proto=sw_f16_type_proto)
+    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f32_name, type_proto=sw_f32_type_proto)
     graph.value_info.append(sw_f16_value_info)

     # Change the data type of w16 (output of 2nd DQ) to model weight precision type
-    if w16_name in value_info_map:
-        value_info_map[w16_name].type.tensor_type.elem_type = onnx_dtype_map[precision_dtype]
+    if w32_name in value_info_map:
+        value_info_map[w32_name].type.tensor_type.elem_type = onnx_dtype_map["Float"]
     else:
-        raise ValueError(f"ValueInfo for {w16_name} not found.")
+        raise ValueError(f"ValueInfo for {w32_name} not found.")

     # Add the new nodes to the graph
     graph.node.extend([dequant1, dequant2])
@@ -1522,7 +1520,6 @@ def _get_precision_dtype() -> str:
         w_f4,
         sw_f32_per_tensor,
         sw_f8_per_block,
-        precision_dtype,
         block_size,
     )
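As a follow-up, here is a small sanity check (not part of the PR) that could be run on a model exported after this change: every DQ1 scale output, which the code above names with a `_f32_scale` suffix and registers in `graph.value_info`, should be declared as FP32. The model path is a placeholder.

```python
# Verify that FP4 weight-scale DequantizeLinear outputs are typed FP32 ("model.onnx" is a placeholder).
import onnx

model = onnx.load("model.onnx")
vi_types = {vi.name: vi.type.tensor_type.elem_type for vi in model.graph.value_info}

for node in model.graph.node:
    if node.op_type == "DequantizeLinear" and node.output[0].endswith("_f32_scale"):
        elem_type = vi_types.get(node.output[0])
        assert elem_type == onnx.TensorProto.FLOAT, (
            f"{node.name}: expected FP32 output, got elem_type={elem_type}"
        )
print("All FP4 weight-scale DequantizeLinear outputs are FP32.")
```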