Commit 4ea72e3

[5504719] Fix bypassing of 'Cast' connecting a consumer with multiple outputs a… (#309)
Signed-off-by: gcunhase <[email protected]>
1 parent 4716131 commit 4ea72e3

File tree

2 files changed, +103 -14 lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 25 additions & 14 deletions
@@ -566,28 +566,39 @@ def convert_initializer(
                 to_type=self.high_precision_type,
             )
 
+    def _replace_tensor_name(
+        self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
+    ) -> None:
+        """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
+        for consumer in consumers:
+            for idx, inp in enumerate(consumer.input):
+                if inp == original_tensor_name:
+                    consumer.input[idx] = new_tensor_name
+
     def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
         # handling only a single input and output, as we only remove cast nodes
         assert len(node.input) == 1
         assert len(node.output) == 1
 
         input_tensor = node.input[0]
         output_tensor = node.output[0]
-        is_output_producer = False
 
-        # If removed cast node is producing a network output, we need to update the node producing the cast
-        # Network output name should not be changed
-        for output in self.model.graph.output:
-            if output.name == output_tensor:
-                is_output_producer = True
-                producers = utils.get_producer_nodes(self.model, input_tensor)
-                for producer in producers:
-                    for i, prod_out in enumerate(producer.output):
-                        if prod_out == input_tensor:
-                            producer.output[i] = output_tensor
-        if (
-            not is_output_producer
-        ):  # Reconnect consumers of the cast output to use the cast input instead
+        # Check if the cast output is also a graph output
+        is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)
+
+        # If the removed cast node is producing a network output, update the producer of the cast input so
+        # the network output name is preserved.
+        if is_output_producer:
+            producers = utils.get_producer_nodes(self.model, input_tensor)
+            for producer in producers:
+                for i, prod_out in enumerate(producer.output):
+                    if prod_out == input_tensor:
+                        producer.output[i] = output_tensor
+                        consumers = utils.get_consumer_nodes(self.model, prod_out)
+                        if len(consumers) > 1:
+                            self._replace_tensor_name(consumers, prod_out, output_tensor)
+        else:
+            # Reconnect consumers of the cast output to use the cast input instead
             consumers = utils.get_consumer_nodes(self.model, output_tensor)
             for consumer in consumers:
                 for i, input_name in enumerate(consumer.input):
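
For illustration only (not part of this commit), here is a standalone sketch of what the new _replace_tensor_name helper does, reusing the tensor and node names from the test fixture below. Renaming the producer's output to the graph-output name would otherwise leave a second consumer pointing at a tensor that no longer exists.

# Sketch only: mirrors PrecisionConverter._replace_tensor_name on plain NodeProto objects.
from onnx import TensorProto, helper

# Two consumers that both read the tensor "concat_1_out".
cast_0 = helper.make_node("Cast", ["concat_1_out"], ["Y1"], name="cast_0", to=TensorProto.FLOAT)
concat_2 = helper.make_node("Concat", ["concat_1_out", "X3"], ["Y2"], name="concat_2", axis=1)

# Equivalent of self._replace_tensor_name([cast_0, concat_2], "concat_1_out", "Y1"):
for consumer in (cast_0, concat_2):
    for idx, inp in enumerate(consumer.input):
        if inp == "concat_1_out":
            consumer.input[idx] = "Y1"

print(list(concat_2.input))  # ['Y1', 'X3'] -> the extra consumer now follows the renamed tensor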

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 78 additions & 0 deletions
@@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_multiple_output_node_casted_to_output():
+    """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
+    # Create inputs and outputs
+    x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 2, 16, 16])
+    x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 16, 16])
+    x3 = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 4, 16, 16])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 5, 16, 16])
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 9, 16, 16])
+
+    # Create computation nodes
+    concat_1_node = helper.make_node(
+        "Concat",
+        ["X1", "X2"],
+        ["concat_1_out"],
+        name="concat_1",
+        axis=1,
+    )
+    concat_2_node = helper.make_node(
+        "Concat",
+        ["concat_1_out", "X3"],
+        ["Y2"],
+        name="concat_2",
+        axis=1,
+    )
+
+    # Create a Cast node between 'concat_1' and the graph output
+    cast_node = helper.make_node(
+        "Cast",
+        ["concat_1_out"],
+        ["Y1"],
+        name="cast_0",
+        to=TensorProto.FLOAT,
+    )
+
+    graph = helper.make_graph(
+        [concat_1_node, concat_2_node, cast_node],
+        "model_with_multiple_output_node_casted_to_output",
+        [x1, x2, x3],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(
+        graph, producer_name="model_with_multiple_output_node_casted_to_output"
+    )
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = (
+        model_with_multiple_output_node_casted_to_output
+    )
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=True,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
+    )
+    onnx.checker.check_model(converted_model)
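
As a hedged follow-up, a sketch of extra checks one could append to this test (not part of the committed change; it assumes it runs right after check_model in the same scope): the graph output names should survive the conversion, and no node should be left referencing a tensor with no producer after the Cast bypass.

# Hypothetical extra assertions, not part of this commit.
output_names = {o.name for o in converted_model.graph.output}
assert {"Y1", "Y2"} <= output_names  # network output names are preserved

# Every remaining node input must be defined somewhere (graph input, initializer,
# or another node's output), i.e. no dangling reference to the bypassed Cast's input.
defined = (
    {i.name for i in converted_model.graph.input}
    | {init.name for init in converted_model.graph.initializer}
    | {out for n in converted_model.graph.node for out in n.output}
)
assert all(inp in defined for n in converted_model.graph.node for inp in n.input if inp)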
