Skip to content

Commit ce9045a

Browse files
committed
Inject identity nodes in sanitizer; revert existing logic; update test
Signed-off-by: Ali Boubezari <[email protected]>
1 parent 03529fc commit ce9045a

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -332,20 +332,26 @@ def sanitize_io_casts(self) -> None:
332332
model_output_names = {output.name for output in self.model.graph.output}
333333
nodes_to_add = []
334334
for node in self.model.graph.node:
335-
if node.op_type == "Cast":
336-
if node.input[0] in model_input_names and node.output[0] in model_output_names:
337-
cast_input_name = node.input[0]
338-
cast_output_name = node.output[0]
339-
cast_new_output_name = cast_input_name + "_io_cast_identity"
340-
nodes_to_add.append(
341-
helper.make_node(
342-
"Identity",
343-
inputs=[cast_new_output_name],
344-
outputs=[cast_output_name],
345-
name=node.name + "_io_cast_identity",
346-
)
335+
if (
336+
node.op_type == "Cast"
337+
and node.input
338+
and node.output
339+
and node.input[0] in model_input_names
340+
and node.output[0] in model_output_names
341+
):
342+
# Unique per graph output to avoid collisions when multiple outputs are cast from the same input
343+
cast_output_name = node.output[0]
344+
cast_new_output_name = f"{cast_output_name}__io_cast_src"
345+
nodes_to_add.append(
346+
helper.make_node(
347+
"Identity",
348+
inputs=[cast_new_output_name],
349+
outputs=[cast_output_name],
350+
name=f"{node.name}__io_cast_identity",
347351
)
348-
node.output[0] = cast_new_output_name
352+
)
353+
# Rewire Cast to produce the new intermediate
354+
node.output[0] = cast_new_output_name
349355

350356
for node in nodes_to_add:
351357
self.model.graph.node.append(node)

0 commit comments

Comments
 (0)