@@ -181,13 +181,12 @@ def make_value_info_from_tensor(tensor):
181181 "Upsample" ,
182182 # NEW:
183183 "RandomNormalLike" ,
184-
185184 # TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
186185 # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
187186 # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
188187 # However, without it, the graphs produced are invalid. Eventually, we will resolve this.
189188 "Cast" ,
190- ]
189+ ]
191190
192191
193192def initial_checking (model , disable_shape_infer ):
@@ -333,7 +332,7 @@ def process_graph_input(
333332 graph ,
334333 [graph_input .name ],
335334 [cast_node_output_name ],
336- cast_node_output_name , # Set node name same as output name
335+ cast_node_output_name , # Set node name same as output name
337336 FLOAT16 ,
338337 )
339338 add_new_value_info (
@@ -388,7 +387,9 @@ def process_graph_output(
388387 )
389388 for value_info in graph .value_info :
390389 if original_name == value_info .name :
391- value_info .type .tensor_type .elem_type = onnx_proto .TensorProto .FLOAT
390+ value_info .type .tensor_type .elem_type = (
391+ onnx_proto .TensorProto .FLOAT
392+ )
392393
393394 # Get the node(s) that consume the model output
394395 downstream_nodes = find_downstream_node_by_input_name (
@@ -545,7 +546,11 @@ def process_initializers(
545546 initializer_block_list = set ()
546547 for node in graph .node :
547548 if (node .op_type in op_block_list ) or (node .name in node_block_list ):
548- for input_name in node .input : # some is initializer, some is value_info, can't distinguish but doesn't matter
549+ for (
550+ input_name
551+ ) in (
552+ node .input
553+ ): # some is initializer, some is value_info, can't distinguish but doesn't matter
549554 initializer_block_list .add (input_name )
550555 # Process initializers
551556 for initializer in graph .initializer :
0 commit comments