Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,28 +566,39 @@ def convert_initializer(
to_type=self.high_precision_type,
)

def _replace_tensor_name(
    self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
) -> None:
    """Rewire consumer inputs from one tensor name to another, in place.

    Args:
        consumers: Nodes whose ``input`` sequences are scanned.
        original_tensor_name: Input name to look for.
        new_tensor_name: Name written over every matching input slot.
    """
    for node in consumers:
        # Collect the matching slots first, then overwrite them; a node may
        # reference the same tensor in more than one input position.
        matching_slots = [
            slot for slot, name in enumerate(node.input) if name == original_tensor_name
        ]
        for slot in matching_slots:
            node.input[slot] = new_tensor_name

def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
# handling only a single input and output, as we only remove cast nodes
assert len(node.input) == 1
assert len(node.output) == 1

input_tensor = node.input[0]
output_tensor = node.output[0]
is_output_producer = False

# If removed cast node is producing a network output, we need to update the node producing the cast
# Network output name should not be changed
for output in self.model.graph.output:
if output.name == output_tensor:
is_output_producer = True
producers = utils.get_producer_nodes(self.model, input_tensor)
for producer in producers:
for i, prod_out in enumerate(producer.output):
if prod_out == input_tensor:
producer.output[i] = output_tensor
if (
not is_output_producer
): # Reconnect consumers of the cast output to use the cast input instead
# Check if the cast output is also a graph output
is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)

# If the removed cast node is producing a network output, update the producer of the cast input so
# the network output name is preserved.
if is_output_producer:
producers = utils.get_producer_nodes(self.model, input_tensor)
for producer in producers:
for i, prod_out in enumerate(producer.output):
if prod_out == input_tensor:
producer.output[i] = output_tensor
consumers = utils.get_consumer_nodes(self.model, prod_out)
if len(consumers) > 1:
self._replace_tensor_name(consumers, prod_out, output_tensor)
else:
# Reconnect consumers of the cast output to use the cast input instead
consumers = utils.get_consumer_nodes(self.model, output_tensor)
for consumer in consumers:
for i, input_name in enumerate(consumer.input):
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"


@pytest.fixture
def model_with_multiple_output_node_casted_to_output():
    """Build a model where one tensor feeds both another node and a Cast producing a graph output.

    ``concat_1`` produces ``concat_1_out``, which is consumed by ``concat_2`` (producing graph
    output ``Y2``) and by ``cast_0`` (producing graph output ``Y1``).

    Returns:
        Tuple of ``(model, value_info_map, initializer_map, node_to_init_map)``.
    """

    def float_io(name, channels):
        # All graph inputs/outputs are FLOAT tensors that differ only in channel count.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, channels, 16, 16])

    graph_inputs = [float_io("X1", 2), float_io("X2", 3), float_io("X3", 4)]
    graph_outputs = [float_io("Y1", 5), float_io("Y2", 9)]

    nodes = [
        helper.make_node("Concat", ["X1", "X2"], ["concat_1_out"], name="concat_1", axis=1),
        helper.make_node("Concat", ["concat_1_out", "X3"], ["Y2"], name="concat_2", axis=1),
        # Cast sitting between 'concat_1' and the graph output 'Y1'.
        helper.make_node("Cast", ["concat_1_out"], ["Y1"], name="cast_0", to=TensorProto.FLOAT),
    ]

    graph_name = "model_with_multiple_output_node_casted_to_output"
    graph = helper.make_graph(nodes, graph_name, graph_inputs, graph_outputs, [])

    model = helper.make_model(graph, producer_name=graph_name)
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    inferred_model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(inferred_model)

    return inferred_model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_multiple_output_node_casted_to_output(
    model_with_multiple_output_node_casted_to_output, low_precision_type
):
    """Check the converted model is valid when both Concat nodes are requested in low precision.

    The fixture routes ``concat_1_out`` both to ``concat_2`` and through a Cast to a graph
    output; the conversion result must still pass ``onnx.checker.check_model``.
    """
    fixture_model, value_infos, initializers, node_inits = (
        model_with_multiple_output_node_casted_to_output
    )

    precision_converter = PrecisionConverter(
        fixture_model,
        value_infos,
        initializers,
        node_inits,
        keep_io_types=True,
        low_precision_type=low_precision_type,
    )
    low_precision_node_names = ["concat_1", "concat_2"]
    result = precision_converter.convert(
        high_precision_nodes=[], low_precision_nodes=low_precision_node_names
    )

    onnx.checker.check_model(result)
Loading