Skip to content

Commit 434baf6

Browse files
committed
Fix bypassing of a 'Cast' node that connects a producer whose output has multiple consumers to the model's output
Signed-off-by: gcunhase <[email protected]>
1 parent a6fa34c commit 434baf6

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,12 @@ def convert_initializer(
557557
to_type=self.high_precision_type,
558558
)
559559

560+
def _replace_tensor_name(
    self,
    consumers,
    original_tensor_name: str,
    new_tensor_name: str,
) -> None:
    """Rewire consumer inputs from one tensor name to another, in place.

    Every entry of each consumer's ``input`` sequence that equals
    ``original_tensor_name`` is overwritten with ``new_tensor_name``;
    all other entries are left untouched.

    Args:
        consumers: Iterable of nodes exposing a mutable, indexable ``input``
            sequence of tensor names (the visible caller passes the result
            of ``utils.get_consumer_nodes``).
        original_tensor_name: Tensor name to look for in each node's inputs.
        new_tensor_name: Tensor name written in place of every match.
    """
    for consumer in consumers:
        # A node may reference the same tensor in several input slots, so
        # every slot is checked rather than stopping at the first match.
        for idx, inp in enumerate(consumer.input):
            if inp == original_tensor_name:
                consumer.input[idx] = new_tensor_name
565+
560566
def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
561567
# handling only a single input and output, as we only remove cast nodes
562568
assert len(node.input) == 1
@@ -576,6 +582,9 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
576582
for i, prod_out in enumerate(producer.output):
577583
if prod_out == input_tensor:
578584
producer.output[i] = output_tensor
585+
consumers = utils.get_consumer_nodes(self.model, prod_out)
586+
if len(consumers) > 1:
587+
self._replace_tensor_name(consumers, prod_out, output_tensor)
579588
if (
580589
not is_output_producer
581590
): # Reconnect consumers of the cast output to use the cast input instead

0 commit comments

Comments
 (0)