Skip to content

Commit 1f8fde5

Browse files
committed
Refactor: simplified '_bypass_cast_node' function
Signed-off-by: gcunhase <[email protected]>
1 parent 6323d22 commit 1f8fde5

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -579,24 +579,23 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
579579

580580
input_tensor = node.input[0]
581581
output_tensor = node.output[0]
582-
is_output_producer = False
583582

584-
# If removed cast node is producing a network output, we need to update the node producing the cast
585-
# Network output name should not be changed
586-
for output in self.model.graph.output:
587-
if output.name == output_tensor:
588-
is_output_producer = True
589-
producers = utils.get_producer_nodes(self.model, input_tensor)
590-
for producer in producers:
591-
for i, prod_out in enumerate(producer.output):
592-
if prod_out == input_tensor:
593-
producer.output[i] = output_tensor
594-
consumers = utils.get_consumer_nodes(self.model, prod_out)
595-
if len(consumers) > 1:
596-
self._replace_tensor_name(consumers, prod_out, output_tensor)
597-
if (
598-
not is_output_producer
599-
): # Reconnect consumers of the cast output to use the cast input instead
583+
# Check if the cast output is also a graph output
584+
is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)
585+
586+
# If the removed cast node is producing a network output, we need to update the node producing the cast, as
587+
# the network output name should not be changed
588+
if is_output_producer:
589+
producers = utils.get_producer_nodes(self.model, input_tensor)
590+
for producer in producers:
591+
for i, prod_out in enumerate(producer.output):
592+
if prod_out == input_tensor:
593+
producer.output[i] = output_tensor
594+
consumers = utils.get_consumer_nodes(self.model, prod_out)
595+
if len(consumers) > 1:
596+
self._replace_tensor_name(consumers, prod_out, output_tensor)
597+
else:
598+
# Reconnect consumers of the cast output to use the cast input instead
600599
consumers = utils.get_consumer_nodes(self.model, output_tensor)
601600
for consumer in consumers:
602601
for i, input_name in enumerate(consumer.input):

0 commit comments

Comments
 (0)