@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn
 
 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
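A tiny numeric sketch (made-up values, not from any model) of the two branches above: the FP8 path keeps fractional quantized values and adds the zero point as-is, while the integer path rounds to whole codes.

import numpy as np

weight_array = np.array([0.30, -0.70], dtype=np.float32)
scale_array = np.array([0.25], dtype=np.float32)
zp_array = np.zeros_like(weight_array)

print(np.asarray(weight_array / scale_array) + zp_array)   # FP8 path: [ 1.2 -2.8]
print(np.asarray((weight_array / scale_array).round()))    # INT path: [ 1. -3.]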
@@ -607,17 +607,22 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
     """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
 
 
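A standalone sketch of the nibble packing done in _cast_fp4 (made-up values; assumes each uint8 element already holds a 4-bit FP4E2M1 code in its low nibble, with even-indexed codes landing in the low nibble of the packed byte):

import numpy as np

codes = np.array([0x1, 0xA, 0x3, 0x7], dtype=np.uint8)            # four 4-bit codes
packed = codes[::2] | (codes[1::2] << 4)                          # -> [0xA1, 0x73]
unpacked = np.stack([packed & 0xF, packed >> 4], axis=-1).ravel()
assert (unpacked == codes).all()                                  # round-trips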
@@ -685,7 +690,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
 
         # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
@@ -920,6 +925,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]
 
+        # Remove the shape Constant node feeding the Reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
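The tensor_producer_map lookup above assumes a name-to-producer index built earlier in the function; a minimal sketch of that assumption (hypothetical two-node graph):

from onnx import helper

nodes = [
    helper.make_node("Constant", [], ["shape_Constant_out"], name="shape_Constant"),
    helper.make_node("Reshape", ["w", "shape_Constant_out"], ["w_reshaped"], name="w_Reshape"),
]
tensor_producer_map = {out: n for n in nodes for out in n.output}
print(tensor_producer_map["shape_Constant_out"].name)   # -> "shape_Constant"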
@@ -937,12 +946,17 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
         assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
 
+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -959,7 +973,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -990,6 +1004,21 @@ def quantize_weights_to_int4(
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")
 
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
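A minimal standalone sketch (hypothetical node names) of the rewiring pattern used above: the Mul takes over the Cast's output name, so downstream consumers stay connected once the Cast is dropped.

from onnx import helper, TensorProto

mul = helper.make_node("Mul", ["x", "w_pre_quant_scale"], ["mul_out"], name="Mul_0")
cast = helper.make_node("Cast", ["mul_out"], ["cast_out"], to=TensorProto.FLOAT16, name="Cast_0")

del mul.output[:]                  # drop "mul_out"
mul.output.extend(cast.output)     # Mul now writes directly to "cast_out"
# ...then rebuild graph.node without Cast_0, as done via nodes_to_remove above.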
@@ -1004,7 +1033,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1099,7 +1128,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")
 
@@ -1181,11 +1216,24 @@ def _add_input_value_info(graph, tensor_proto):
         sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"
 
         # Create TensorProto for initializers
-        w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+        w_f4_proto = onnx.helper.make_tensor(
+            name=w_f4_name,
+            data_type=onnx_dtype_map["Float4"],
+            dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+            vals=w_f4.tobytes(),
+            raw=True,
+        )
         sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
             sw_f32_per_tensor, sw_f32_per_tensor_name
         )
         sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+        sw_f8_per_block_proto = onnx.helper.make_tensor(
+            name=sw_f8_per_block_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*sw_f8_per_block.shape],
+            vals=sw_f8_per_block.tobytes(),
+            raw=True,
+        )
 
         # Add ValueInfo for the initializers if not present
         _add_input_value_info(graph, w_f4_proto)
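A rough, self-contained sketch (assumed sizes) of the raw FP4 initializer built above: the packed array stores two FP4E2M1 codes per byte, so the logical leading dimension passed to make_tensor is twice the packed array's.

import numpy as np
import onnx

packed = np.zeros((8, 16), dtype=np.uint8)            # stand-in for the packed _cast_fp4 output
w_fp4 = onnx.helper.make_tensor(
    name="w_fp4_example",
    data_type=onnx.TensorProto.FLOAT4E2M1,            # needs an onnx build with FP4 support
    dims=[packed.shape[0] * 2, *packed.shape[1:]],    # 16 x 16 logical elements
    vals=packed.tobytes(),
    raw=True,
)
print(w_fp4.dims)                                      # [16, 16]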