Skip to content

Commit ac76dc7

Browse files
committed
address review comments
Signed-off-by: Ali Boubezari <[email protected]>
1 parent 4028e2a commit ac76dc7

File tree

2 files changed: 15 additions and 16 deletions.

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -351,8 +351,9 @@ def sanitize_io_casts(self) -> None:
351351
"""
352352
model_input_names = {input.name for input in self.model.graph.input}
353353
model_output_names = {output.name for output in self.model.graph.output}
354-
nodes_to_add = []
355-
for node in self.model.graph.node:
354+
insertions: list[tuple[int, onnx.NodeProto]] = []
355+
356+
for idx, node in enumerate(self.model.graph.node):
356357
if (
357358
node.op_type == "Cast"
358359
and node.input
@@ -363,23 +364,19 @@ def sanitize_io_casts(self) -> None:
363364
# Unique per graph output to avoid collisions when multiple outputs are cast from the same input
364365
cast_output_name = node.output[0]
365366
cast_new_output_name = f"{cast_output_name}__io_cast_src"
366-
nodes_to_add.append(
367-
helper.make_node(
368-
"Identity",
369-
inputs=[cast_new_output_name],
370-
outputs=[cast_output_name],
371-
name=f"{node.name}__io_cast_identity",
372-
)
367+
identity_node = helper.make_node(
368+
"Identity",
369+
inputs=[cast_new_output_name],
370+
outputs=[cast_output_name],
371+
name=f"{node.name}__io_cast_identity",
373372
)
374373
# Rewire Cast to produce the new intermediate
375374
node.output[0] = cast_new_output_name
375+
insertions.append((idx + 1, identity_node))
376376

377-
for node in nodes_to_add:
378-
self.model.graph.node.append(node)
379-
380-
# Make sure the graph is topologically sorted
381-
gs_graph = gs.import_onnx(self.model).cleanup().toposort()
382-
self.model = gs.export_onnx(gs_graph)
377+
# Insert Identities in-order right after their corresponding Casts
378+
for offset, (pos, id_node) in enumerate(insertions):
379+
self.model.graph.node.insert(pos + offset, id_node)
383380

384381
def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
385382
"""Create a LayerNormalization node with optional bias."""

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727

2828
LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
2929

30+
3031
def low_precision_onnx_type(low_precision_type_str):
3132
return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16
3233

@@ -1103,6 +1104,7 @@ def test_multiple_output_node_casted_to_output(
11031104
)
11041105
onnx.checker.check_model(converted_model)
11051106

1107+
11061108
@pytest.fixture
11071109
def model_with_casted_input_to_output():
11081110
"""Create a model with an output produced by a Cast node."""
@@ -1170,4 +1172,4 @@ def test_casted_input_to_output_model(
11701172
converted_model = converter.convert(
11711173
high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
11721174
)
1173-
onnx.checker.check_model(converted_model)
1175+
onnx.checker.check_model(converted_model)

0 commit comments

Comments
 (0)