@@ -1107,17 +1107,20 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"

-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]

         # Transpose weights and scales if present
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -1134,7 +1137,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = cast_child_nodes[0]
+            matmul_node = next_node
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -1185,21 +1188,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
     del graph.node[:]
     graph.node.extend(new_nodes)

-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
-    for node in graph.node:
-        if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if "norm/Cast" in node.name and is_fp32_cast(node):
-                continue
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx.TensorProto.FLOAT16
-
     # Cast bias to float16
     for node in graph.node:
         if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1306,13 +1294,6 @@ def quantize_weights_to_mxfp8(
             if attr.name == "output_dtype":
                 attr.i = onnx_dtype_map["Half"]

-    # set Cast to FP16
-    for node in graph.node:
-        if node.op_type == "Cast":
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx_dtype_map["Half"]
-
     # Currently only tanh approximation is supported for Gelu
     for node in gelu_nodes:
         for attr in node.attribute:
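
Note: the first two hunks make the Cast between Reshape and the downstream Transpose/MatMul/Gemm optional instead of required. A minimal sketch of the resulting traversal, not part of the diff, where find_consumers is a hypothetical stand-in for the [n for n in graph.node if ...] lookups used above:

# Sketch only: walk Reshape -> [Cast] -> [Transpose] -> MatMul/Gemm, recording
# any skipped nodes for later removal. `find_consumers` is a hypothetical helper
# equivalent to the list comprehensions in the diff.
def find_consumers(graph, output_name):
    return [n for n in graph.node if output_name in n.input]

def locate_matmul_after_reshape(graph, reshape_node, nodes_to_remove):
    children = find_consumers(graph, reshape_node.output[0])
    assert len(children) == 1, "Expected exactly one child node"
    next_node = children[0]
    if next_node.op_type == "Cast":  # optional Cast; may or may not be present
        nodes_to_remove.append(next_node.name)
        next_node = find_consumers(graph, next_node.output[0])[0]
    if next_node.op_type == "Transpose":  # optional Transpose before the GEMM
        nodes_to_remove.append(next_node.name)
        next_node = find_consumers(graph, next_node.output[0])[0]
    assert next_node.op_type in ["MatMul", "Gemm"]
    return next_node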