@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn
 
 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
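
For reference, the integer branch above is plain symmetric quantization: divide by the (broadcast) scale and round. A minimal numpy sketch with hypothetical values; the clip to the INT8 range is my addition and is assumed to happen elsewhere in the real pipeline:

import numpy as np

# Toy weights and per-output-channel scales (hypothetical values).
weight_array = np.array([[0.5, -1.0], [2.0, 0.25]], dtype=np.float32)
scale_array = np.array([[0.01], [0.02]], dtype=np.float32)

# Scale, round (numpy rounds half to even: 12.5 -> 12), then narrow to int8.
scaled = np.asarray((weight_array / scale_array).round())
int8_weights = np.clip(scaled, -128, 127).astype(np.int8)
print(int8_weights)  # [[  50 -100]
                     #  [ 100   12]]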
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2,
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
 
 
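The packing added to _cast_fp4 stores two 4-bit FP4E2M1 codes per byte: the even-indexed element goes in the low nibble, the odd-indexed element in the high nibble, which is why the first axis halves. A standalone sketch of the same bit arithmetic on hypothetical code values:

import numpy as np

# Four hypothetical 4-bit FP4E2M1 codes, one per array element.
codes = np.array([0x1, 0x7, 0x2, 0xF], dtype=np.uint8)

# Pack pairs into single bytes, exactly as in _cast_fp4 above.
packed = codes[::2] | (codes[1::2] << 4)
print([hex(b) for b in packed])  # ['0x71', '0xf2']

# Unpacking recovers the original codes.
low, high = packed & 0x0F, packed >> 4
assert np.array_equal(np.stack([low, high], axis=1).ravel(), codes)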
@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
 
             # Create and update new weight tensor
-            if zp_array.dtype == float8e4m3fn:
+            if zp_array.dtype == onnx_dtype_map["Float8"]:
                 new_weight = _create_fp8_tensor(scaled, weight_name)
                 logger.debug(f"Converted {weight_name} to FP8")
             else:
@@ -925,6 +934,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]
 
+        # Remove the shape Constant node that feeds the Reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
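
The tensor_producer_map used above is not shown in this diff; a plausible construction, assuming it maps each output tensor name to the node that produces it, would be:

# Assumed shape of tensor_producer_map: tensor name -> producing NodeProto,
# so the Constant feeding the Reshape's shape input can be located by name.
tensor_producer_map = {output: node for node in graph.node for output in node.output}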
@@ -942,12 +955,17 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
         assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
 
+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -964,7 +982,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -995,6 +1013,21 @@ def quantize_weights_to_int4(
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")
 
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after each pre-quant scale Mul
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
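
Clearing the Mul's outputs and adopting the Cast's output names keeps every downstream consumer wired correctly without touching their inputs. A toy sketch of the same rewiring on two hand-built nodes (names are hypothetical):

import onnx

# A pre-quant scale Mul followed by a redundant Cast (hypothetical names).
mul = onnx.helper.make_node(
    "Mul", ["x", "w_pre_quant_scale"], ["mul_out"], name="pqs_mul"
)
cast = onnx.helper.make_node(
    "Cast", ["mul_out"], ["cast_out"], to=onnx.TensorProto.FLOAT16, name="pqs_cast"
)

# Rewire: the Mul now produces "cast_out" directly, so consumers of the
# Cast's output are untouched once the Cast node itself is dropped.
mul.output.clear()
mul.output.extend(cast.output)
assert list(mul.output) == ["cast_out"]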
@@ -1009,7 +1042,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
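
is_fp32_cast is referenced here but its body is not shown in this excerpt; a plausible implementation, matching how the loop above inspects the "to" attribute, might be:

import onnx

# Assumed helper: true when a Cast node's "to" attribute targets FLOAT (fp32).
def is_fp32_cast(node: onnx.NodeProto) -> bool:
    return any(
        attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
    )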
@@ -1104,7 +1137,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")
 
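Since _cast_fp8 now returns plain uint8 code bytes, numpy_helper.from_array would declare the initializer as UINT8; make_tensor with raw=True lets the FP8 element type live in the TensorProto metadata instead of the buffer. A self-contained sketch with hypothetical data:

import numpy as np
import onnx

# Hypothetical FP8E4M3 bit patterns, stored as raw uint8 code bytes.
fp8_codes = np.zeros((4, 8), dtype=np.uint8)

# Declare the bytes as FLOAT8E4M3FN: numpy has no native float8 dtype, so the
# element type is carried by the TensorProto, not by the numpy buffer.
tensor = onnx.helper.make_tensor(
    name="w_fp8",
    data_type=onnx.TensorProto.FLOAT8E4M3FN,
    dims=list(fp8_codes.shape),
    vals=fp8_codes.tobytes(),
    raw=True,
)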
@@ -1186,11 +1225,24 @@ def _add_input_value_info(graph, tensor_proto):
         sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"
 
         # Create TensorProto for initializers
-        w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+        w_f4_proto = onnx.helper.make_tensor(
+            name=w_f4_name,
+            data_type=onnx_dtype_map["Float4"],
+            dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+            vals=w_f4.tobytes(),
+            raw=True,
+        )
         sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
             sw_f32_per_tensor, sw_f32_per_tensor_name
         )
         sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+        sw_f8_per_block_proto = onnx.helper.make_tensor(
+            name=sw_f8_per_block_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*sw_f8_per_block.shape],
+            vals=sw_f8_per_block.tobytes(),
+            raw=True,
+        )
 
         # Add ValueInfo for the initializers if not present
         _add_input_value_info(graph, w_f4_proto)
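
Note the dims arithmetic for the FP4 weight: the packed buffer from _cast_fp4 has its first axis halved, so the declared ONNX dims double it back to the logical FP4 element count. A small sketch of that bookkeeping (shapes are hypothetical):

import numpy as np

# Packed buffer: two FP4 codes per byte, so the first axis is halved.
w_f4 = np.zeros((64, 128), dtype=np.uint8)  # hypothetical packed weights

# Logical FP4 dims declared in the TensorProto: first axis doubled back.
logical_dims = [w_f4.shape[0] * 2, *w_f4.shape[1:]]
assert int(np.prod(logical_dims)) == 2 * w_f4.size  # 128 * 128 == 2 * 64 * 128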