@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn

 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
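Note: the new "Float4" entry below relies on onnx.TensorProto.FLOAT4E2M1, which only exists in recent onnx releases (added around onnx 1.18 / opset 23). A minimal guard sketch, assuming older onnx versions must still import this module (the hasattr check is illustrative, not part of this change):

    import onnx

    # Register the FP4 dtype only when the installed onnx build defines it.
    if hasattr(onnx.TensorProto, "FLOAT4E2M1"):
        onnx_dtype_map["Float4"] = onnx.TensorProto.FLOAT4E2M1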
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)

     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
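The two branches above implement the usual fake-quant arithmetic: the FP8 path keeps fractional values and adds the zero-point, while the integer path rounds to the nearest code. A small numeric sketch with assumed values (not from this change):

    import numpy as np

    weight_array = np.array([1.0, -2.4], dtype=np.float32)
    scale_array = np.float32(0.5)
    zp_array = np.float32(0.0)  # FP8 zero-points are typically zero

    # FP8 branch: values stay fractional until the later FP8 cast
    fp8_scaled = np.asarray(weight_array / scale_array) + zp_array   # [ 2. , -4.8]

    # INT8 branch: round to the nearest integer code
    int8_scaled = np.asarray((weight_array / scale_array).round())   # [ 2., -5.]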
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8


 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4

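For reference, the packing above follows ONNX's 4-bit convention: consecutive elements pair up in row-major order, with the even-index element in the low nibble. A self-contained numpy round trip (names are illustrative):

    import numpy as np

    codes = np.array([1, 2, 3, 4], dtype=np.uint8)  # 4-bit codes, each < 16
    packed = codes[::2] | (codes[1::2] << 4)        # -> [0x21, 0x43]

    # Unpack: low nibble holds even indices, high nibble holds odd indices
    unpacked = np.stack([packed & 0x0F, packed >> 4], axis=-1).reshape(-1)
    assert (unpacked == codes).all()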
@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

         # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
@@ -925,6 +934,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]

+        # Remove the Constant node feeding the Reshape node's shape input
+        shape_constant_name = next(inp for inp in reshape_node.input if "Constant" in inp)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
@@ -942,12 +955,17 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"

+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -964,7 +982,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -995,6 +1013,21 @@ def quantize_weights_to_int4(
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")

+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any("_pre_quant_scale" in inp for inp in node.input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
@@ -1009,7 +1042,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            # Skip Cast nodes that are part of normalization layers
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1104,7 +1137,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")

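A worked example of the e8m0 arithmetic above, assuming the conventional exponent bias of 127 for e8_m0_bias (names mirror the code, values are illustrative):

    import numpy as np

    e8_m0_bias = 127
    se8m0_fp32 = np.float32(130.0)              # stored scale exponent code
    scale = np.exp2(se8m0_fp32 - e8_m0_bias)    # 2**(130 - 127) = 8.0
    weight = np.float32(20.0)
    scaled_weight = weight / scale              # 2.5, then cast to FP8 E4M3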
@@ -1186,11 +1225,23 @@ def _add_input_value_info(graph, tensor_proto):
     sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

     # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
     sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
         sw_f32_per_tensor, sw_f32_per_tensor_name
     )
-    sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )

     # Add ValueInfo for the initializers if not present
     _add_input_value_info(graph, w_f4_proto)
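The dims of w_f4_proto are worth a note: w_f4 holds packed bytes (two FP4 values each), so the logical first dimension is twice the packed array's first axis. A hedged, self-contained sketch of that size relationship (all names here are illustrative, not from the change):

    import numpy as np
    import onnx

    w_f4 = np.zeros((8, 16), dtype=np.uint8)  # packed: 2 FP4 values per byte
    proto = onnx.helper.make_tensor(
        name="w_example",
        data_type=onnx.TensorProto.FLOAT4E2M1,
        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],  # logical 16x16 elements
        vals=w_f4.tobytes(),
        raw=True,
    )
    # ONNX stores 4-bit types two per byte, so raw_data is half the element count
    assert len(proto.raw_data) == (16 * 16) // 2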