File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed
Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change @@ -710,6 +710,36 @@ def model(x):
710710 self .assertEqual (len (onnx_opset_import ), 1 )
711711 self .assertEqual (onnx_opset_import [0 ].version , 19 )
712712
713+ def test_traced_if (self ):
714+ """Test that traced if statements are converted correctly."""
715+
716+ @script ()
717+ def add_model (x : FLOAT [10 ]) -> FLOAT [10 ]:
718+ y = op .Add (x , x )
719+ return y
720+
721+ @script ()
722+ def sub_model (x : FLOAT [10 ]) -> FLOAT [10 ]:
723+ y = op .Sub (x , x )
724+ return y
725+
726+ def make_model (flag : bool ):
727+ @script ()
728+ def model (x : FLOAT [10 ]) -> FLOAT [10 ]:
729+ if flag :
730+ y = op .Add (x , x )
731+ else :
732+ y = op .Sub (x , x )
733+ return y
734+
735+ return model .to_model_proto ()
736+
737+ model_true = make_model (True )
738+ onnxscript .testing .assert_isomorphic (model_true , add_model .to_model_proto ())
739+
740+ model_false = make_model (False )
741+ onnxscript .testing .assert_isomorphic (model_false , sub_model .to_model_proto ())
742+
713743
714744if __name__ == "__main__" :
715745 unittest .main (verbosity = 2 )
You can’t perform that action at this time.
0 commit comments