diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index c3edc967f..2d26e3b0c 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1332,7 +1332,6 @@ def replace_fp4qdq_with_2dq(
     w_f4: np.ndarray,
     sw_f32_per_tensor: np.ndarray,
     sw_f8_per_block: np.ndarray,
-    precision_dtype: str,
     block_size: int,
 ):
     """Replaces the given node in the ONNX graph with a subgraph consisting of two DequantizeLinear nodes.
@@ -1346,7 +1345,6 @@ def replace_fp4qdq_with_2dq(
         w_f4: NumPy array for w_f4.
         sw_f32_per_tensor: NumPy array for sw_f32_per_tensor.
         sw_f8_per_block: NumPy array for sw_f8_per_block.
-        precision_dtype: The precision of the weights.
         block_size: Block size used in block quantization.
     """
 
@@ -1406,39 +1404,39 @@ def _add_input_value_info(graph, tensor_proto):
     _add_initializer(sw_f32_per_tensor_proto)
     _add_initializer(sw_f8_per_block_proto)
 
-    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f16
-    sw_f16_name = weight_name + "_f16_scale"
+    # Create DequantizeLinear_1 node: (sw_f8_per_block, sw_f32_per_tensor) -> sw_f32
+    sw_f32_name = weight_name + "_f32_scale"
     dequant1 = onnx.helper.make_node(
         "DequantizeLinear",
         inputs=[sw_f8_per_block_proto.name, sw_f32_per_tensor_proto.name],
-        outputs=[sw_f16_name],
+        outputs=[sw_f32_name],
         name=weight_name + "_DequantizeLinear",
     )
 
-    # Create DequantizeLinear_2 node: (w_f4, sw_f16) -> w_16
-    w16_name = node.output[0]
+    # Create DequantizeLinear_2 node: (w_f4, sw_f32) -> w_32
+    w32_name = node.output[0]
     dequant2 = onnx.helper.make_node(
         "DequantizeLinear",
-        inputs=[w_f4_proto.name, sw_f16_name],
-        outputs=[w16_name],
+        inputs=[w_f4_proto.name, sw_f32_name],
+        outputs=[w32_name],
         name=weight_name + "_DequantizeLinear_1",
         axis=-1,
         block_size=block_size,
     )
 
-    # Add value_info for sw_f16
+    # Add value_info for sw_f32
     # Assuming sw_f16 has the same shape as sw_f8_per_block
-    sw_f16_type_proto = onnx.helper.make_tensor_type_proto(
-        elem_type=onnx_dtype_map[precision_dtype], shape=sw_f8_per_block.shape
+    sw_f32_type_proto = onnx.helper.make_tensor_type_proto(
+        elem_type=onnx_dtype_map["Float"], shape=sw_f8_per_block.shape
     )
-    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f16_name, type_proto=sw_f16_type_proto)
+    sw_f16_value_info = onnx.helper.make_value_info(name=sw_f32_name, type_proto=sw_f32_type_proto)
     graph.value_info.append(sw_f16_value_info)
 
     # Change the data type of w16 (output of 2nd DQ) to model weight precision type
-    if w16_name in value_info_map:
-        value_info_map[w16_name].type.tensor_type.elem_type = onnx_dtype_map[precision_dtype]
+    if w32_name in value_info_map:
+        value_info_map[w32_name].type.tensor_type.elem_type = onnx_dtype_map["Float"]
     else:
-        raise ValueError(f"ValueInfo for {w16_name} not found.")
+        raise ValueError(f"ValueInfo for {w32_name} not found.")
 
     # Add the new nodes to the graph
     graph.node.extend([dequant1, dequant2])
@@ -1537,7 +1535,6 @@ def _get_precision_dtype() -> str:
             w_f4,
             sw_f32_per_tensor,
             sw_f8_per_block,
-            precision_dtype,
             block_size,
         )
 
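
For context, a minimal sketch (not part of the patch) of the two-DequantizeLinear pattern this change emits per weight, with both DQ outputs now typed as FP32 instead of the removed precision_dtype. Only onnx.helper.make_node is assumed from the onnx package; the tensor names and block size below are hypothetical placeholders.

import onnx

weight_name = "my_weight"  # hypothetical weight tensor name
block_size = 16            # hypothetical block size

# DQ1: per-block FP8 scales dequantized by the per-tensor FP32 scale -> per-block FP32 scales
dequant1 = onnx.helper.make_node(
    "DequantizeLinear",
    inputs=[weight_name + "_f8_scale", weight_name + "_f32_scale_per_tensor"],
    outputs=[weight_name + "_f32_scale"],
    name=weight_name + "_DequantizeLinear",
)

# DQ2: FP4 weight dequantized block-wise (along the last axis) by the FP32 scales -> FP32 weight
dequant2 = onnx.helper.make_node(
    "DequantizeLinear",
    inputs=[weight_name + "_f4", weight_name + "_f32_scale"],
    outputs=[weight_name + "_dequantized"],
    name=weight_name + "_DequantizeLinear_1",
    axis=-1,
    block_size=block_size,
)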