Skip to content

Commit 3d8d5e3

Browse files
committed
fp16 conversion improvements
1 parent f7a088e commit 3d8d5e3

File tree

1 file changed

+33
-29
lines changed

1 file changed

+33
-29
lines changed

scripts/float16.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,14 @@ def make_value_info_from_tensor(tensor):
180180
"Max",
181181
"Upsample",
182182
# NEW:
183-
"Cast",
184183
"RandomNormalLike",
185-
]
184+
185+
# TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
186+
# - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
187+
# - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
188+
# However, without it, the graphs produced are invalid. Eventually, we will resolve this.
189+
"Cast",
190+
]
186191

187192

188193
def initial_checking(model, disable_shape_infer):
@@ -317,21 +322,26 @@ def process_graph_input(
317322
graph, graph_input.name
318323
)
319324
for d_node in downstream_nodes:
320-
cast_node_name = graph_input.name + "_cast_to_" + d_node.name
321-
cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name
322-
add_cast_node(
323-
graph,
324-
[graph_input.name],
325-
[cast_node_output_name],
326-
cast_node_name,
327-
FLOAT16,
328-
)
329-
add_new_value_info(
330-
graph,
331-
graph_input,
332-
cast_node_output_name,
333-
onnx_proto.TensorProto.FLOAT16,
334-
)
325+
# More than one node may consume the model input, so we only create
326+
# a single cast node, and then reuse this node when needed.
327+
cast_exists = graph_input.name in global_input_name_dict
328+
if cast_exists:
329+
cast_node_output_name = global_input_name_dict[graph_input.name]
330+
else:
331+
cast_node_output_name = graph_input.name + "_fp16"
332+
add_cast_node(
333+
graph,
334+
[graph_input.name],
335+
[cast_node_output_name],
336+
cast_node_output_name, # Set node name same as output name
337+
FLOAT16,
338+
)
339+
add_new_value_info(
340+
graph,
341+
graph_input,
342+
cast_node_output_name,
343+
onnx_proto.TensorProto.FLOAT16,
344+
)
335345
for i, input_name in enumerate(d_node.input):
336346
if input_name == graph_input.name:
337347
d_node.input[i] = (
@@ -378,9 +388,7 @@ def process_graph_output(
378388
)
379389
for value_info in graph.value_info:
380390
if original_name == value_info.name:
381-
value_info.type.tensor_type.elem_type = (
382-
onnx_proto.TensorProto.FLOAT
383-
)
391+
value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
384392

385393
# Get the node(s) that consume the model output
386394
downstream_nodes = find_downstream_node_by_input_name(
@@ -420,7 +428,7 @@ def process_node_in_block_list(
420428
def insert_cast32_before_node(
421429
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
422430
):
423-
for i in range(len(node.input)):
431+
for i, input_name in enumerate(node.input):
424432
input_name = node.input[i]
425433
for value_info in itertools.chain(graph.value_info, graph.input):
426434
if input_name == value_info.name:
@@ -449,7 +457,7 @@ def insert_cast32_before_node(
449457
def insert_cast16_after_node(
450458
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
451459
):
452-
for i in range(len(node.output)):
460+
for i, output_name in enumerate(node.output):
453461
output_name = node.output[i]
454462
for value_info in itertools.chain(graph.value_info, graph.output):
455463
if output_name == value_info.name:
@@ -537,11 +545,7 @@ def process_initializers(
537545
initializer_block_list = set()
538546
for node in graph.node:
539547
if (node.op_type in op_block_list) or (node.name in node_block_list):
540-
for (
541-
input_name
542-
) in (
543-
node.input
544-
): # some is initializer, some is value_info, can't distinguish but doesn't matter
548+
for input_name in node.input: # some is initializer, some is value_info, can't distinguish but doesn't matter
545549
initializer_block_list.add(input_name)
546550
# Process initializers
547551
for initializer in graph.initializer:
@@ -793,8 +797,8 @@ def get_type(name: str) -> Optional[int]:
793797
else:
794798
if (
795799
downstream_node.op_type == "Cast"
796-
and cast_node.attribute[0].i == 10
797-
and downstream_node.attribute[0].i == 1
800+
and cast_node.attribute[0].i == FLOAT16
801+
and downstream_node.attribute[0].i == FLOAT32
798802
and downstream_node in cast_node_list
799803
and cast_node in cast_node_list
800804
):

0 commit comments

Comments
 (0)