Skip to content

Commit 0b9afd9

Browse files
committed
Minor fix and test case
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent a770d98 commit 0b9afd9

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

onnxscript/rewriter/rules/fusion/_rotary_embedding.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,10 @@ def rewrite(self, op, x, freqs, **_):
6868
num_heads = x.shape[1]
6969
cos = op.Cos(freqs)
7070
sin = op.Sin(freqs)
71-
cos_4d = op.Unsqueeze(cos, 1)
72-
sin_4d = op.Unsqueeze(sin, 1)
7371
return op.RotaryEmbedding(
7472
x,
75-
cos_4d,
76-
sin_4d,
73+
cos,
74+
sin,
7775
interleaved=0,
7876
num_heads=num_heads,
7977
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import onnx_ir as ir
8+
from parameterized import parameterized
9+
10+
import onnxscript
11+
from onnxscript.rewriter import onnx_fusions
12+
from onnxscript.rewriter.models import _rotary_embedding_models
13+
import onnxscript.rewriter.testing
14+
15+
class RotaryEmbeddingOnnxFusionTest(unittest.TestCase):
16+
@parameterized.expand(
17+
[
18+
(
19+
"test_case_1",
20+
_rotary_embedding_models.test_case_1,
21+
),
22+
(
23+
"test_case_2",
24+
_rotary_embedding_models.test_case_2,
25+
),
26+
]
27+
)
28+
def test_rotary_embedding_fusion(self, _: str, test_data_constructor):
29+
test = test_data_constructor()
30+
for opset_version in [22, 23]:
31+
model: ir.Model = test.get_onnx_model()
32+
model.graph.opset_imports[""] = opset_version
33+
model_proto = ir.serde.serialize_model(model)
34+
onnxscript.optimizer.optimize(model)
35+
onnx_fusions.fuse(model)
36+
op_types = [n.op_type for n in model.graph]
37+
if opset_version == 22:
38+
self.assertNotIn("RotaryEmbedding", op_types)
39+
else:
40+
self.assertIn("RotaryEmbedding", op_types)
41+
rewritten_model_proto = ir.serde.serialize_model(model)
42+
inputs = test.get_ort_inputs()
43+
onnxscript.rewriter.testing.assert_numerically_equal(
44+
model_proto, rewritten_model_proto, args=inputs, use_reference=True
45+
)
46+
47+
48+
if __name__ == "__main__":
49+
unittest.main()

0 commit comments

Comments
 (0)