@@ -566,7 +566,10 @@ def convert_initializer(
566
566
to_type = self .high_precision_type ,
567
567
)
568
568
569
- def _replace_tensor_name (self , consumers , original_tensor_name , new_tensor_name ):
569
+ def _replace_tensor_name (
570
+ self , consumers : list [onnx .NodeProto ], original_tensor_name : str , new_tensor_name : str
571
+ ) -> None :
572
+ """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
570
573
for consumer in consumers :
571
574
for idx , inp in enumerate (consumer .input ):
572
575
if inp == original_tensor_name :
@@ -583,8 +586,8 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
583
586
# Check if the cast output is also a graph output
584
587
is_output_producer = any (output .name == output_tensor for output in self .model .graph .output )
585
588
586
- # If the removed cast node is producing a network output, we need to update the node producing the cast, as
587
- # the network output name should not be changed
589
+ # If the removed cast node is producing a network output, update the producer of the cast input so
590
+ # the network output name is preserved.
588
591
if is_output_producer :
589
592
producers = utils .get_producer_nodes (self .model , input_tensor )
590
593
for producer in producers :
0 commit comments