Skip to content

Commit 42c122d

Browse files
committed
Fix bypassing of a 'Cast' node that connects a producer whose output has multiple consumers to the model's output
Signed-off-by: gcunhase <[email protected]>
1 parent 4716131 commit 42c122d

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,12 @@ def convert_initializer(
566566
to_type=self.high_precision_type,
567567
)
568568

569+
def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
570+
for consumer in consumers:
571+
for idx, inp in enumerate(consumer.input):
572+
if inp == original_tensor_name:
573+
consumer.input[idx] = new_tensor_name
574+
569575
def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
570576
# handling only a single input and output, as we only remove cast nodes
571577
assert len(node.input) == 1
@@ -585,6 +591,9 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
585591
for i, prod_out in enumerate(producer.output):
586592
if prod_out == input_tensor:
587593
producer.output[i] = output_tensor
594+
consumers = utils.get_consumer_nodes(self.model, prod_out)
595+
if len(consumers) > 1:
596+
self._replace_tensor_name(consumers, prod_out, output_tensor)
588597
if (
589598
not is_output_producer
590599
): # Reconnect consumers of the cast output to use the cast input instead

0 commit comments

Comments
 (0)