import onnx_graphsurgeon as gs
import torch
from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn

from modelopt.onnx import utils
from modelopt.onnx.logging_config import logger
onnx_dtype_map = {
    "BFloat16": onnx.TensorProto.BFLOAT16,
    "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
    "Float8": onnx.TensorProto.FLOAT8E4M3FN,
    "Half": onnx.TensorProto.FLOAT16,
    "INT8": onnx.TensorProto.INT8,
@@ -529,6 +529,11 @@ def _get_successive_consumers(
    quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
    if not quantized_node:
        raise ValueError(f"No consumer found for {dq_node.name}")
+    if quantized_node.op_type == "Cast":
+        next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
+        if not next_node:
+            raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
+        quantized_node = next_node

    return dq_node, quantized_node
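For context, the new branch above covers graphs where a Cast sits between the DequantizeLinear and the compute node; the real quantized consumer is then one hop past the Cast. A minimal sketch of that lookup, using made-up node and tensor names rather than anything from this file:

import onnx

# Hypothetical DequantizeLinear -> Cast -> MatMul chain.
dq = onnx.helper.make_node("DequantizeLinear", ["w_q", "w_scale"], ["w_dq"], name="w/DQ")
cast = onnx.helper.make_node("Cast", ["w_dq"], ["w_dq_f16"], to=onnx.TensorProto.FLOAT16, name="w/Cast")
matmul = onnx.helper.make_node("MatMul", ["x", "w_dq_f16"], ["y"], name="proj/MatMul")

tensor_consumers = {"w_dq": [cast], "w_dq_f16": [matmul]}
consumer = tensor_consumers.get(dq.output[0], [None])[0]
if consumer is not None and consumer.op_type == "Cast":
    # The real quantized consumer is one hop past the Cast.
    consumer = tensor_consumers.get(consumer.output[0], [None])[0]
assert consumer is not None and consumer.op_type == "MatMul"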
@@ -592,7 +597,7 @@ def _convert_weight(
        zp_array = zp_array.reshape(*reshape_dims)

    # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
        scaled = np.asarray(weight_array / scale_array) + zp_array
    else:
        scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +612,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
    if torch.cuda.is_available():
        array_f32_t = array_f32_t.cuda()
    array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
    return array_f8


def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
    array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
    if torch.cuda.is_available():
        array_f32_t = array_f32_t.cuda()
    array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
    return array_f4
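As a toy illustration of the packing described in the new docstring (this is not the library code): two 4-bit codes share one byte, the even-indexed value in the low nibble and the odd-indexed value in the high nibble, so the packed buffer holds half as many bytes as there are FP4 values.

import numpy as np

codes = np.array([1, 7, 2, 15, 0, 9], dtype=np.uint8)  # six already-quantized FP4 codes (0..15)
packed = codes[::2] | (codes[1::2] << 4)                # three bytes: 0x71, 0xF2, 0x90
lo, hi = packed & 0x0F, packed >> 4                     # recover both nibbles
assert np.array_equal(np.stack([lo, hi], axis=1).ravel(), codes)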
@@ -685,7 +699,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
        scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

        # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
            new_weight = _create_fp8_tensor(scaled, weight_name)
            logger.debug(f"Converted {weight_name} to FP8")
        else:
@@ -925,6 +939,10 @@ def quantize_weights_to_int4(
        assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
        reshape_node_output = reshape_node.output[0]

+        # Remove constant node from reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
        # Get the shape of the output of the reshape node
        reshape_output_value_info = value_info_map.get(reshape_node_output)
        if reshape_output_value_info is not None:
@@ -942,12 +960,17 @@ def quantize_weights_to_int4(
            scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
            scale = scale.reshape(scale_shape)
        reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"

+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
        # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
            nodes_to_remove.append(transpose_node.name)
            assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
            perm = None
@@ -964,7 +987,7 @@ def quantize_weights_to_int4(
            )
            matmul_node = transpose_child_nodes[0]
        else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
        assert matmul_node.op_type in ["MatMul", "Gemm"], (
            f"Expected MatMul or Gemm node for {node.name}"
        )
@@ -995,9 +1018,24 @@ def quantize_weights_to_int4(
        initializer_map[weight_name].CopyFrom(weights_int4_onnx)
        logger.debug(f"Converted {weight_name} to INT4 precision")

+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
    # Remove transpose and reshape nodes
    new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
-    graph.node.clear()
+    del graph.node[:]
    graph.node.extend(new_nodes)

    def is_fp32_cast(node: onnx.NodeProto) -> bool:
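The Cast-splicing added above works purely on NodeProto fields: the Mul adopts the Cast's output name, so downstream consumers keep resolving once the Cast is dropped from the graph. A standalone sketch of that rewiring with hypothetical node and tensor names:

import onnx

mul = onnx.helper.make_node(
    "Mul", ["x", "attn_pre_quant_scale"], ["x_scaled"], name="attn/Mul"
)
cast = onnx.helper.make_node(
    "Cast", ["x_scaled"], ["x_scaled_f16"], to=onnx.TensorProto.FLOAT16, name="attn/Cast"
)

# Rewire: the Mul now produces the tensor the Cast used to produce.
del mul.output[:]
mul.output.extend(cast.output)
assert list(mul.output) == ["x_scaled_f16"]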
@@ -1009,7 +1047,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
    for node in graph.node:
        if node.op_type == "Cast":
            # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                continue
            for attr in node.attribute:
                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1104,7 +1142,13 @@ def quantize_weights_to_mxfp8(
        # Expand block array so that it can be broadcasted with weight
        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
        initializer_map[weight_name].CopyFrom(weights_e4m3)
        logger.debug(f"Converted {weight_name} to MXFP8")
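The switch from numpy_helper.from_array to onnx.helper.make_tensor is needed because NumPy has no native float8 dtype: _cast_fp8 now returns plain uint8 bit patterns, which are attached as raw bytes while the TensorProto records the real ONNX element type. A minimal standalone sketch (the shape and initializer name are made up):

import numpy as np
import onnx

fp8_bits = np.zeros((128, 64), dtype=np.uint8)  # placeholder FP8E4M3 bit patterns
proto = onnx.helper.make_tensor(
    name="layer0.weight",                        # hypothetical initializer name
    data_type=onnx.TensorProto.FLOAT8E4M3FN,
    dims=list(fp8_bits.shape),
    vals=fp8_bits.tobytes(),
    raw=True,
)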
@@ -1186,11 +1230,24 @@ def _add_input_value_info(graph, tensor_proto):
    sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

    # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
    sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
        sw_f32_per_tensor, sw_f32_per_tensor_name
    )
    sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )

    # Add ValueInfo for the initializers if not present
    _add_input_value_info(graph, w_f4_proto)
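One detail worth calling out in w_f4_proto above: each byte of w_f4 packs two FP4 values along the first axis, so the declared dims describe the logical (unpacked) shape, which is why the first dimension is multiplied by 2. A quick sanity check under that assumption, with a hypothetical packed shape:

import numpy as np

w_f4_packed_shape = (64, 256)  # hypothetical packed uint8 shape
logical_dims = [w_f4_packed_shape[0] * 2, *w_f4_packed_shape[1:]]
# One byte carries two FP4 elements, so the raw payload is half the logical element count.
assert np.prod(logical_dims) == 2 * np.prod(w_f4_packed_shape)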