Skip to content

Commit e6c4551

Browse files
authored
Merge branch 'main' into justinchu/consolidate-index
2 parents a115fa9 + 726be2b commit e6c4551

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,31 @@ def test_constant_folding_creates_constant_nodes_in_function(self):
741741
constant_nodes = [n for n in func.graph if n.op_type == "Constant"]
742742
self.assertEqual(len(constant_nodes), 1)
743743

744+
def test_initializer_as_graph_output_is_not_removed(self):
745+
"""Test that an initializer that is a graph output is not removed during constant folding."""
746+
model = """
747+
<ir_version: 7, opset_import: [ "" : 17]>
748+
agraph (float[N] x) => (float[N] y, float z) {
749+
constant = Constant <value_float=2.0> ()
750+
y = Mul(x, constant)
751+
z = Identity(constant)
752+
}
753+
"""
754+
755+
optimized = self._fold(model)
756+
# After constant folding, the Identity node should be folded, and 'constant'
757+
# should become an initializer with the output name 'z'.
758+
# The key thing is that this initializer should NOT be removed even though
759+
# the Identity node was folded, because it is a graph output.
760+
self.assertIn("z", optimized.graph.initializers)
761+
# The Identity node should be removed
762+
identity_nodes = [n for n in optimized.graph if n.op_type == "Identity"]
763+
self.assertEqual(len(identity_nodes), 0)
764+
# Verify the graph still has both outputs
765+
output_names = [o.name for o in optimized.graph.outputs]
766+
self.assertIn("y", output_names)
767+
self.assertIn("z", output_names)
768+
744769

745770
if __name__ == "__main__":
746771
unittest.main()

0 commit comments

Comments
 (0)