Skip to content

Commit 159b22e

Browse files
committed
Formatting
1 parent 3d8d5e3 commit 159b22e

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

scripts/float16.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,12 @@ def make_value_info_from_tensor(tensor):
181181
"Upsample",
182182
# NEW:
183183
"RandomNormalLike",
184-
185184
# TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
186185
# - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
187186
# - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
188187
# However, without it, the graphs produced are invalid. Eventually, we will resolve this.
189188
"Cast",
190-
]
189+
]
191190

192191

193192
def initial_checking(model, disable_shape_infer):
@@ -333,7 +332,7 @@ def process_graph_input(
333332
graph,
334333
[graph_input.name],
335334
[cast_node_output_name],
336-
cast_node_output_name, # Set node name same as output name
335+
cast_node_output_name, # Set node name same as output name
337336
FLOAT16,
338337
)
339338
add_new_value_info(
@@ -388,7 +387,9 @@ def process_graph_output(
388387
)
389388
for value_info in graph.value_info:
390389
if original_name == value_info.name:
391-
value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
390+
value_info.type.tensor_type.elem_type = (
391+
onnx_proto.TensorProto.FLOAT
392+
)
392393

393394
# Get the node(s) that consume the model output
394395
downstream_nodes = find_downstream_node_by_input_name(
@@ -545,7 +546,11 @@ def process_initializers(
545546
initializer_block_list = set()
546547
for node in graph.node:
547548
if (node.op_type in op_block_list) or (node.name in node_block_list):
548-
for input_name in node.input: # some is initializer, some is value_info, can't distinguish but doesn't matter
549+
for (
550+
input_name
551+
) in (
552+
node.input
553+
): # some is initializer, some is value_info, can't distinguish but doesn't matter
549554
initializer_block_list.add(input_name)
550555
# Process initializers
551556
for initializer in graph.initializer:

0 commit comments

Comments
 (0)