Skip to content

Commit caf9d39

Browse files
committed
move pass
Signed-off-by: Ali Boubezari <[email protected]>
1 parent ce9045a commit caf9d39

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def sanitize(self) -> None:
6363
self.ensure_graph_name_exists()
6464
onnx_utils.name_onnx_nodes(self.model.graph)
6565
self.replace_custom_domain_nodes()
66+
self.sanitize_io_casts()
6667
self.cleanup_model()
6768
self.set_ir_version(self.max_ir_version)
68-
self.sanitize_io_casts()
6969

7070
def find_custom_nodes(self) -> None:
7171
"""Find custom nodes in the model.
@@ -356,6 +356,10 @@ def sanitize_io_casts(self) -> None:
356356
for node in nodes_to_add:
357357
self.model.graph.node.append(node)
358358

359+
# Make sure the graph is topologically sorted
360+
gs_graph = gs.import_onnx(self.model).cleanup().toposort()
361+
self.model = gs.export_onnx(gs_graph)
362+
359363
def _create_layernorm_node(self, pattern: dict) -> onnx.NodeProto:
360364
"""Create a LayerNormalization node with optional bias."""
361365
ln_name = f"LayerNorm_{pattern['mean_node'].name}"

0 commit comments

Comments
 (0)