@@ -180,9 +180,14 @@ def make_value_info_from_tensor(tensor):
180180 "Max" ,
181181 "Upsample" ,
182182 # NEW:
183- "Cast" ,
184183 "RandomNormalLike" ,
185- ]
184+
185+ # TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
186+ # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
187+ # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
188+ # However, without it, the graphs produced are invalid. Eventually, we will resolve this.
189+ "Cast" ,
190+ ]
186191
187192
188193def initial_checking (model , disable_shape_infer ):
@@ -317,21 +322,26 @@ def process_graph_input(
317322 graph , graph_input .name
318323 )
319324 for d_node in downstream_nodes :
320- cast_node_name = graph_input .name + "_cast_to_" + d_node .name
321- cast_node_output_name = graph_input .name + "_cast_to_" + d_node .name
322- add_cast_node (
323- graph ,
324- [graph_input .name ],
325- [cast_node_output_name ],
326- cast_node_name ,
327- FLOAT16 ,
328- )
329- add_new_value_info (
330- graph ,
331- graph_input ,
332- cast_node_output_name ,
333- onnx_proto .TensorProto .FLOAT16 ,
334- )
325+ # More than one node may consume the model input, so we only create
326+ # a single cast node, and then reuse this node when needed.
327+ cast_exists = graph_input .name in global_input_name_dict
328+ if cast_exists :
329+ cast_node_output_name = global_input_name_dict [graph_input .name ]
330+ else :
331+ cast_node_output_name = graph_input .name + "_fp16"
332+ add_cast_node (
333+ graph ,
334+ [graph_input .name ],
335+ [cast_node_output_name ],
336+ cast_node_output_name , # Set node name same as output name
337+ FLOAT16 ,
338+ )
339+ add_new_value_info (
340+ graph ,
341+ graph_input ,
342+ cast_node_output_name ,
343+ onnx_proto .TensorProto .FLOAT16 ,
344+ )
335345 for i , input_name in enumerate (d_node .input ):
336346 if input_name == graph_input .name :
337347 d_node .input [i ] = (
@@ -378,9 +388,7 @@ def process_graph_output(
378388 )
379389 for value_info in graph .value_info :
380390 if original_name == value_info .name :
381- value_info .type .tensor_type .elem_type = (
382- onnx_proto .TensorProto .FLOAT
383- )
391+ value_info .type .tensor_type .elem_type = onnx_proto .TensorProto .FLOAT
384392
385393 # Get the node(s) that consume the model output
386394 downstream_nodes = find_downstream_node_by_input_name (
@@ -420,7 +428,7 @@ def process_node_in_block_list(
420428def insert_cast32_before_node (
421429 graph : onnx_proto .GraphProto , node : onnx_proto .NodeProto , global_input_name_dict
422430):
423- for i in range ( len ( node .input ) ):
431+ for i , input_name in enumerate ( node .input ):
424432 input_name = node .input [i ]
425433 for value_info in itertools .chain (graph .value_info , graph .input ):
426434 if input_name == value_info .name :
@@ -449,7 +457,7 @@ def insert_cast32_before_node(
449457def insert_cast16_after_node (
450458 graph : onnx_proto .GraphProto , node : onnx_proto .NodeProto , global_input_name_dict
451459):
452- for i in range ( len ( node .output ) ):
460+ for i , output_name in enumerate ( node .output ):
453461 output_name = node .output [i ]
454462 for value_info in itertools .chain (graph .value_info , graph .output ):
455463 if output_name == value_info .name :
@@ -537,11 +545,7 @@ def process_initializers(
537545 initializer_block_list = set ()
538546 for node in graph .node :
539547 if (node .op_type in op_block_list ) or (node .name in node_block_list ):
540- for (
541- input_name
542- ) in (
543- node .input
544- ): # some is initializer, some is value_info, can't distinguish but doesn't matter
548+ for input_name in node .input : # some is initializer, some is value_info, can't distinguish but doesn't matter
545549 initializer_block_list .add (input_name )
546550 # Process initializers
547551 for initializer in graph .initializer :
@@ -793,8 +797,8 @@ def get_type(name: str) -> Optional[int]:
793797 else :
794798 if (
795799 downstream_node .op_type == "Cast"
796- and cast_node .attribute [0 ].i == 10
797- and downstream_node .attribute [0 ].i == 1
800+ and cast_node .attribute [0 ].i == FLOAT16
801+ and downstream_node .attribute [0 ].i == FLOAT32
798802 and downstream_node in cast_node_list
799803 and cast_node in cast_node_list
800804 ):
0 commit comments