@@ -579,24 +579,23 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
579
579
580
580
input_tensor = node .input [0 ]
581
581
output_tensor = node .output [0 ]
582
- is_output_producer = False
583
582
584
- # If removed cast node is producing a network output, we need to update the node producing the cast
585
- # Network output name should not be changed
586
- for output in self . model . graph . output :
587
- if output . name == output_tensor :
588
- is_output_producer = True
589
- producers = utils . get_producer_nodes ( self . model , input_tensor )
590
- for producer in producers :
591
- for i , prod_out in enumerate ( producer . output ) :
592
- if prod_out == input_tensor :
593
- producer . output [ i ] = output_tensor
594
- consumers = utils . get_consumer_nodes ( self . model , prod_out )
595
- if len ( consumers ) > 1 :
596
- self . _replace_tensor_name (consumers , prod_out , output_tensor )
597
- if (
598
- not is_output_producer
599
- ): # Reconnect consumers of the cast output to use the cast input instead
583
+ # Check if the cast output is also a graph output
584
+ is_output_producer = any ( output . name == output_tensor for output in self . model . graph . output )
585
+
586
+ # If the removed cast node is producing a network output, we need to update the node producing the cast, as
587
+ # the network output name should not be changed
588
+ if is_output_producer :
589
+ producers = utils . get_producer_nodes ( self . model , input_tensor )
590
+ for producer in producers :
591
+ for i , prod_out in enumerate ( producer . output ) :
592
+ if prod_out == input_tensor :
593
+ producer . output [ i ] = output_tensor
594
+ consumers = utils . get_consumer_nodes ( self . model , prod_out )
595
+ if len (consumers ) > 1 :
596
+ self . _replace_tensor_name ( consumers , prod_out , output_tensor )
597
+ else :
598
+ # Reconnect consumers of the cast output to use the cast input instead
600
599
consumers = utils .get_consumer_nodes (self .model , output_tensor )
601
600
for consumer in consumers :
602
601
for i , input_name in enumerate (consumer .input ):
0 commit comments