modelopt/onnx/quantization/qdq_utils.py (31 changes: 14 additions & 17 deletions)
@@ -1332,7 +1332,6 @@ def replace_fp4qdq_with_2dq(
     w_f4: np.ndarray,
     sw_f32_per_tensor: np.ndarray,
     sw_f8_per_block: np.ndarray,
-    precision_dtype: str,
     block_size: int,
 ):
     """Replaces the given node in the ONNX graph with a subgraph consisting of two DequantizeLinear nodes.
@@ -1346,7 +1345,6 @@ w_f4: NumPy array for w_f4.
         w_f4: NumPy array for w_f4.
         sw_f32_per_tensor: NumPy array for sw_f32_per_tensor.
         sw_f8_per_block: NumPy array for sw_f8_per_block.
-        precision_dtype: The precision of the weights.
         block_size: Block size used in block quantization.
     """
 
@@ -1406,39 +1404,39 @@ def _add_input_value_info(graph, tensor_proto):
     _add_initializer(sw_f32_per_tensor_proto)
     _add_initializer(sw_f8_per_block_proto)
 
-    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f16
-    sw_f16_name = weight_name + "_f16_scale"
+    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f32
+    sw_f32_name = weight_name + "_f32_scale"
     dequant1 = onnx.helper.make_node(
         "DequantizeLinear",
         inputs=[sw_f8_per_block_proto.name, sw_f32_per_tensor_proto.name],
-        outputs=[sw_f16_name],
+        outputs=[sw_f32_name],
         name=weight_name + "_DequantizeLinear",
     )
 
-    # Create DequantizeLinear_2 node: (w_f4, sw_f16) -> w_16
-    w16_name = node.output[0]
+    # Create DequantizeLinear_2 node: (w_f4, sw_f32) -> w_32
+    w32_name = node.output[0]
     dequant2 = onnx.helper.make_node(
         "DequantizeLinear",
-        inputs=[w_f4_proto.name, sw_f16_name],
-        outputs=[w16_name],
+        inputs=[w_f4_proto.name, sw_f32_name],
+        outputs=[w32_name],
         name=weight_name + "_DequantizeLinear_1",
         axis=-1,
         block_size=block_size,
     )
 
-    # Add value_info for sw_f16
+    # Add value_info for sw_f32
     # Assuming sw_f16 has the same shape as sw_f8_per_block
-    sw_f16_type_proto = onnx.helper.make_tensor_type_proto(
-        elem_type=onnx_dtype_map[precision_dtype], shape=sw_f8_per_block.shape
+    sw_f32_type_proto = onnx.helper.make_tensor_type_proto(
+        elem_type=onnx_dtype_map["Float"], shape=sw_f8_per_block.shape
     )
-    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f16_name, type_proto=sw_f16_type_proto)
+    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f32_name, type_proto=sw_f32_type_proto)
     graph.value_info.append(sw_f16_value_info)
 
     # Change the data type of w16 (output of 2nd DQ) to model weight precision type
-    if w16_name in value_info_map:
-        value_info_map[w16_name].type.tensor_type.elem_type = onnx_dtype_map[precision_dtype]
+    if w32_name in value_info_map:
+        value_info_map[w32_name].type.tensor_type.elem_type = onnx_dtype_map["Float"]
     else:
-        raise ValueError(f"ValueInfo for {w16_name} not found.")
+        raise ValueError(f"ValueInfo for {w32_name} not found.")
 
     # Add the new nodes to the graph
     graph.node.extend([dequant1, dequant2])
@@ -1537,7 +1535,6 @@ def _get_precision_dtype() -> str:
         w_f4,
         sw_f32_per_tensor,
         sw_f8_per_block,
-        precision_dtype,
         block_size,
     )
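
For reference, a minimal NumPy sketch of what the resulting two-DQ subgraph computes after this change. Shapes, values, and variable names below are illustrative assumptions, not part of the PR or the ONNX API: DequantizeLinear_1 rescales the per-block FP8 scales by the per-tensor FP32 scale, and DequantizeLinear_2 expands those per-block scales along the last axis (block_size elements per block) and applies them to the FP4 weight codes.

# Minimal sketch (assumed shapes/values) of the dequantization the two DQ nodes perform.
import numpy as np

block_size = 16
rows, cols = 8, 64

w_f4 = np.random.randint(-6, 7, size=(rows, cols)).astype(np.float32)          # stand-in for the FP4 weight codes
sw_f8_per_block = np.random.rand(rows, cols // block_size).astype(np.float32)  # per-block scales (FP8 in the real graph)
sw_f32_per_tensor = np.float32(0.5)                                            # per-tensor FP32 scale

# DequantizeLinear_1: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f32
sw_f32 = sw_f8_per_block * sw_f32_per_tensor

# DequantizeLinear_2: (w_f4, sw_f32) -> w_f32, blockwise along the last axis
w_f32 = w_f4 * np.repeat(sw_f32, block_size, axis=-1)

print(w_f32.shape)  # (8, 64): float32 weights recovered from FP4 codes and scales

With precision_dtype removed, both DQ outputs are typed via onnx_dtype_map["Float"], i.e. the dequantized scale and weight value_info are declared as float32 regardless of the original model weight precision.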
