Skip to content

Commit 2d8192a

Browse files
committed
Add tests for opset 22 also
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent a181003 commit 2d8192a

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,16 @@ def rms_norm_script(embedding, layernorm_weight):
5151
)
5252
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
5353
test = test_data_constructor()
54-
model: ir.Model = test.get_onnx_model()
55-
model.graph.opset_imports[""] = 23 # Test case is a valid opset 23 model
56-
onnxscript.optimizer.optimize(model)
57-
onnx_fusions.fuse(model)
58-
op_types = [n.op_type for n in model.graph]
59-
self.assertIn("RotaryEmbedding", op_types)
54+
for opset_version in [22, 23]:
55+
model: ir.Model = test.get_onnx_model()
56+
model.graph.opset_imports[""] = opset_version
57+
onnxscript.optimizer.optimize(model)
58+
onnx_fusions.fuse(model)
59+
op_types = [n.op_type for n in model.graph]
60+
if opset_version == 22:
61+
self.assertNotIn("RotaryEmbedding", op_types)
62+
else:
63+
self.assertIn("RotaryEmbedding", op_types)
6064

6165

6266
if __name__ == "__main__":

0 commit comments

Comments
 (0)