File tree Expand file tree Collapse file tree 1 file changed +10
-6
lines changed
onnxscript/rewriter/onnx_fusions Expand file tree Collapse file tree 1 file changed +10
-6
lines changed Original file line number Diff line number Diff line change @@ -51,12 +51,16 @@ def rms_norm_script(embedding, layernorm_weight):
5151 )
5252 def test_rotary_embedding_fusion (self , _ : str , test_data_constructor ):
5353 test = test_data_constructor ()
54- model : ir .Model = test .get_onnx_model ()
55- model .graph .opset_imports ["" ] = 23 # Test case is a valid opset 23 model
56- onnxscript .optimizer .optimize (model )
57- onnx_fusions .fuse (model )
58- op_types = [n .op_type for n in model .graph ]
59- self .assertIn ("RotaryEmbedding" , op_types )
54+ for opset_version in [22 , 23 ]:
55+ model : ir .Model = test .get_onnx_model ()
56+ model .graph .opset_imports ["" ] = opset_version
57+ onnxscript .optimizer .optimize (model )
58+ onnx_fusions .fuse (model )
59+ op_types = [n .op_type for n in model .graph ]
60+ if opset_version == 22 :
61+ self .assertNotIn ("RotaryEmbedding" , op_types )
62+ else :
63+ self .assertIn ("RotaryEmbedding" , op_types )
6064
6165
6266if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments