Skip to content

Commit c33fce2

Browse files
authored
Add initial support for RotaryEmbedding fusion for onnx opset 23 (#2450)
Add initial support for RotaryEmbedding fusion for onnx opset 23 --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 061f62b commit c33fce2

20 files changed

+179
-15
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.

onnxscript/rewriter/ort_fusions/models/_rotary_embedding_models.py renamed to onnxscript/rewriter/models/_rotary_embedding_models.py

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

onnxscript/rewriter/onnx_fusions/_onnx_fusions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import onnx_ir as ir
66

7-
from onnxscript.rewriter.onnx_fusions import _rms_normalization
7+
from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding
88

99

1010
def _get_onnx_opset_version(model: ir.Model) -> int | None:
@@ -23,6 +23,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
2323
"""Apply fusions targeting ONNX opset 23."""
2424
counts: dict[str, int] = {}
2525
counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
26+
counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
2627
return counts
2728

2829

0 commit comments

Comments
 (0)