@@ -1111,17 +1111,20 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
 
-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]
 
         # Transpose weights and scales if present
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -1138,7 +1141,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = cast_child_nodes[0]
+            matmul_node = next_node
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -1189,21 +1192,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
     del graph.node[:]
     graph.node.extend(new_nodes)
 
-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
-    for node in graph.node:
-        if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if "norm/Cast" in node.name and is_fp32_cast(node):
-                continue
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx.TensorProto.FLOAT16
-
     # Cast bias to float16
     for node in graph.node:
         if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1310,13 +1298,6 @@ def quantize_weights_to_mxfp8(
             if attr.name == "output_dtype":
                 attr.i = onnx_dtype_map["Half"]
 
-    # set Cast to FP16
-    for node in graph.node:
-        if node.op_type == "Cast":
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx_dtype_map["Half"]
-
     # Currently only tanh approximation is supported for Gelu
     for node in gelu_nodes:
         for attr in node.attribute:
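
A minimal standalone sketch (not part of the commit) of the traversal pattern the first hunk introduces: after the weight Reshape, the single consumer may optionally be a Cast that gets scheduled for removal before expecting the Transpose/MatMul/Gemm. The helper name and toy graph below are hypothetical; only the onnx API calls are real.

```python
import onnx
from onnx import TensorProto, helper

def skip_optional_cast(graph, reshape_node, nodes_to_remove):
    # Find the single consumer of the Reshape output.
    children = [n for n in graph.node if reshape_node.output[0] in n.input]
    assert len(children) == 1, f"Expected exactly one child node for {reshape_node.name}"
    next_node = children[0]
    if next_node.op_type == "Cast":
        # The Cast is redundant once weights are stored in their final dtype.
        nodes_to_remove.append(next_node.name)
        next_node = [n for n in graph.node if next_node.output[0] in n.input][0]
    return next_node

# Toy graph: Reshape -> Cast -> MatMul, so the helper should land on the MatMul.
graph = helper.make_graph(
    [
        helper.make_node("Reshape", ["w", "shape"], ["w_r"], name="w/Reshape"),
        helper.make_node("Cast", ["w_r"], ["w_c"], name="w/Cast", to=TensorProto.FLOAT16),
        helper.make_node("MatMul", ["x", "w_c"], ["y"], name="w/MatMul"),
    ],
    "toy",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1, 4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT16, [1, 4])],
)
nodes_to_remove: list[str] = []
consumer = skip_optional_cast(graph, graph.node[0], nodes_to_remove)
print(consumer.op_type, nodes_to_remove)  # MatMul ['w/Cast']
```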