Skip to content

Commit 8f39bfc

Browse files
Copilotjustinchuby
andcommitted
Add test for initializer as graph output
Co-authored-by: justinchuby <[email protected]>
1 parent 351bc9c commit 8f39bfc

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,31 @@ def test_constant_folding_creates_constant_nodes_in_function(self):
741741
constant_nodes = [n for n in func.graph if n.op_type == "Constant"]
742742
self.assertEqual(len(constant_nodes), 1)
743743

744+
def test_initializer_as_graph_output_is_not_removed(self):
745+
"""Test that an initializer that is a graph output is not removed during constant folding."""
746+
model = """
747+
<ir_version: 7, opset_import: [ "" : 17]>
748+
agraph (float[N] x) => (float[N] y, float z) {
749+
constant = Constant <value_float=2.0> ()
750+
y = Mul(x, constant)
751+
z = Identity(constant)
752+
}
753+
"""
754+
755+
optimized = self._fold(model)
756+
# After constant folding, the Identity node should be folded, and 'constant'
757+
# should become an initializer with the output name 'z'.
758+
# The key thing is that this initializer should NOT be removed even though
759+
# the Identity node was folded, because it is a graph output.
760+
self.assertIn("z", optimized.graph.initializers)
761+
# The Identity node should be removed
762+
identity_nodes = [n for n in optimized.graph if n.op_type == "Identity"]
763+
self.assertEqual(len(identity_nodes), 0)
764+
# Verify the graph still has both outputs
765+
output_names = [o.name for o in optimized.graph.outputs]
766+
self.assertIn("y", output_names)
767+
self.assertIn("z", output_names)
768+
744769

745770
if __name__ == "__main__":
746771
unittest.main()

0 commit comments

Comments
 (0)