 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn
 
 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
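The two new map entries assume an onnx build that actually defines these enum values. A minimal guard, with a hypothetical helper name that is not part of this diff, could fail fast on older installs:

```python
import onnx

def require_low_precision_dtypes() -> None:
    """Hypothetical guard: verify the installed onnx exposes the FP8/FP4 enums."""
    for name in ("FLOAT8E4M3FN", "FLOAT4E2M1"):
        if not hasattr(onnx.TensorProto, name):
            raise RuntimeError(
                f"onnx {onnx.__version__} lacks TensorProto.{name}; upgrade onnx"
            )
```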
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,22 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
     """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
 
 
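The new packing step pairs consecutive FP4 codes in flattened row-major order, low nibble first, then reshapes the byte buffer so its leading dimension is halved. A minimal numpy sketch of the same bit arithmetic (toy code points, independent of NVFP4QTensor):

```python
import numpy as np

codes = np.array([0x1, 0xF, 0x3, 0x8], dtype=np.uint8)  # toy FP4 code points
packed = codes[::2] | (codes[1::2] << 4)  # element 0 -> low nibble, element 1 -> high
assert packed.tolist() == [0xF1, 0x83]

# Unpacking recovers the original order.
lo, hi = packed & 0x0F, packed >> 4
assert np.stack([lo, hi], axis=1).ravel().tolist() == codes.tolist()
```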
@@ -685,7 +690,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
 
         # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
@@ -920,6 +925,10 @@ def quantize_weights_to_int4(
             assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
             reshape_node_output = reshape_node.output[0]
 
+            # Remove constant node from reshape node
+            shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+            nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
             # Get the shape of the output of the reshape node
             reshape_output_value_info = value_info_map.get(reshape_node_output)
             if reshape_output_value_info is not None:
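The Constant lookup assumes a tensor_producer_map built elsewhere in the function; a sketch of the usual construction (not shown in this hunk):

```python
import onnx

def build_tensor_producer_map(graph: onnx.GraphProto) -> dict[str, onnx.NodeProto]:
    """Map every tensor name to the node that produces it."""
    return {out: node for node in graph.node for out in node.output}
```

Note that matching the Reshape's shape input by the substring "Constant" leans on the exporter's naming convention; an exporter that names shape tensors differently would need a different predicate.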
@@ -937,12 +946,17 @@ def quantize_weights_to_int4(
                 scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
                 scale = scale.reshape(scale_shape)
             reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-            # reshape_node.input = []
             assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
 
+            # Remove unnecessary Cast node
+            cast_node = reshape_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
             # Transpose weights and scales if present
-            if reshape_child_nodes[0].op_type == "Transpose":
-                transpose_node = reshape_child_nodes[0]
+            if cast_child_nodes[0].op_type == "Transpose":
+                transpose_node = cast_child_nodes[0]
                 nodes_to_remove.append(transpose_node.name)
                 assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
                 perm = None
@@ -959,7 +973,7 @@ def quantize_weights_to_int4(
                 )
                 matmul_node = transpose_child_nodes[0]
             else:
-                matmul_node = reshape_child_nodes[0]
+                matmul_node = cast_child_nodes[0]
             assert matmul_node.op_type in ["MatMul", "Gemm"], (
                 f"Expected MatMul or Gemm node for {node.name}"
             )
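The `[n for n in graph.node if out in n.input]` consumer scan recurs throughout this function, always paired with a single-consumer assert. A small helper (hypothetical, not part of the change) would make that contract explicit:

```python
import onnx

def single_consumer(graph: onnx.GraphProto, tensor_name: str) -> onnx.NodeProto:
    """Return the unique node that consumes `tensor_name`; raise otherwise."""
    consumers = [n for n in graph.node if tensor_name in n.input]
    if len(consumers) != 1:
        raise ValueError(f"expected one consumer of {tensor_name}, got {len(consumers)}")
    return consumers[0]

# e.g. cast_node = single_consumer(graph, reshape_node.output[0])
```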
@@ -990,6 +1004,21 @@ def quantize_weights_to_int4(
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")
 
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
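The rewiring in the pre-quant-scale cleanup is the classic node-bypass trick: instead of editing every downstream consumer, the Mul adopts the Cast's output name, so consumers keep their input lists untouched. A self-contained toy with hypothetical tensor names:

```python
from onnx import TensorProto, helper

mul = helper.make_node("Mul", ["x", "w_pre_quant_scale"], ["mul_out"], name="pqs_mul")
cast = helper.make_node("Cast", ["mul_out"], ["cast_out"], to=TensorProto.FLOAT16, name="pqs_cast")
matmul = helper.make_node("MatMul", ["cast_out", "w"], ["y"], name="mm")

# Bypass the Cast: the Mul takes over its output name, so `matmul` still reads
# "cast_out" but now gets it directly from the Mul (mirrors output.clear()/extend()).
del mul.output[:]
mul.output.extend(cast.output)
assert mul.output[0] == "cast_out" and "cast_out" in matmul.input
```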
@@ -1004,7 +1033,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
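is_fp32_cast is named in the hunk header but its body is outside this diff; from the call site it plausibly tests whether a Cast targets FP32. A guessed reconstruction, not the PR's actual code:

```python
import onnx

def is_fp32_cast(node: onnx.NodeProto) -> bool:
    """Plausible shape: a Cast node whose `to` attribute is FLOAT."""
    return node.op_type == "Cast" and any(
        attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
    )
```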
@@ -1099,7 +1128,13 @@ def quantize_weights_to_mxfp8(
             # Expand block array so that it can be broadcasted with weight
             se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
             scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-            weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
             initializer_map[weight_name].CopyFrom(weights_e4m3)
             logger.debug(f"Converted {weight_name} to MXFP8")
 
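The switch away from numpy_helper.from_array follows from _cast_fp8 now returning plain uint8: from_array would stamp the initializer as UINT8, while make_tensor with raw bytes attaches the FP8 element type explicitly. A minimal sketch of the pattern with toy bit patterns:

```python
import numpy as np
import onnx

fp8_bytes = np.array([0x3C, 0x40, 0x44], dtype=np.uint8)  # toy FP8 bit patterns
t = onnx.helper.make_tensor(
    name="w_fp8",
    data_type=onnx.TensorProto.FLOAT8E4M3FN,
    dims=[3],
    vals=fp8_bytes.tobytes(),
    raw=True,
)
assert t.data_type == onnx.TensorProto.FLOAT8E4M3FN and len(t.raw_data) == 3
```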
@@ -1181,11 +1216,24 @@ def _add_input_value_info(graph, tensor_proto):
     sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"
 
     # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
     sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
         sw_f32_per_tensor, sw_f32_per_tensor_name
     )
     sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )
 
     # Add ValueInfo for the initializers if not present
     _add_input_value_info(graph, w_f4_proto)
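Note the dims arithmetic for the FP4 initializer: w_f4 is the packed byte buffer with a halved leading dimension, while dims must describe logical FP4 elements, hence the `* 2`. A worked check with an assumed (128, 64) weight:

```python
import numpy as np

rows, cols = 128, 64  # assumed logical FP4 weight shape
w_f4 = np.zeros((rows // 2, cols), dtype=np.uint8)  # packed bytes, leading dim halved
dims = [w_f4.shape[0] * 2, *w_f4.shape[1:]]
assert dims == [rows, cols]
assert w_f4.nbytes * 2 == rows * cols  # two FP4 elements per stored byte
```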