Skip to content

Commit e00358f

Browse files
committed
Add test case
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 3507d36 commit e00358f

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

onnxscript/converter_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,36 @@ def model(x):
710710
self.assertEqual(len(onnx_opset_import), 1)
711711
self.assertEqual(onnx_opset_import[0].version, 19)
712712

713+
def test_traced_if(self):
714+
"""Test that traced if statements are converted correctly."""
715+
716+
@script()
717+
def add_model(x: FLOAT[10]) -> FLOAT[10]:
718+
y = op.Add(x, x)
719+
return y
720+
721+
@script()
722+
def sub_model(x: FLOAT[10]) -> FLOAT[10]:
723+
y = op.Sub(x, x)
724+
return y
725+
726+
def make_model(flag: bool):
727+
@script()
728+
def model(x: FLOAT[10]) -> FLOAT[10]:
729+
if flag:
730+
y = op.Add(x, x)
731+
else:
732+
y = op.Sub(x, x)
733+
return y
734+
735+
return model.to_model_proto()
736+
737+
model_true = make_model(True)
738+
onnxscript.testing.assert_isomorphic(model_true, add_model.to_model_proto())
739+
740+
model_false = make_model(False)
741+
onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto())
742+
713743

714744
if __name__ == "__main__":
715745
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)