Skip to content

Commit 53d4a10

Browse files
committed
nit: comment and function types
Signed-off-by: gcunhase <[email protected]>
1 parent 1f8fde5 commit 53d4a10

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,10 @@ def convert_initializer(
566566
to_type=self.high_precision_type,
567567
)
568568

569-
def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
569+
def _replace_tensor_name(
570+
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
571+
) -> None:
572+
"""Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
570573
for consumer in consumers:
571574
for idx, inp in enumerate(consumer.input):
572575
if inp == original_tensor_name:
@@ -583,8 +586,8 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
583586
# Check if the cast output is also a graph output
584587
is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)
585588

586-
# If the removed cast node is producing a network output, we need to update the node producing the cast, as
587-
# the network output name should not be changed
589+
# If the removed cast node is producing a network output, update the producer of the cast input so
590+
# the network output name is preserved.
588591
if is_output_producer:
589592
producers = utils.get_producer_nodes(self.model, input_tensor)
590593
for producer in producers:

0 commit comments

Comments
 (0)