Skip to content

Commit e25ef84

Browse files
committed
Implement SDPA via MHA
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent d80575d commit e25ef84

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
29 29   fuse_rotary_embedding,
30 30   )
31 31   from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
   32 + from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
32 33   from onnxscript.rewriter.ort_fusions.skip_normalization import (
33 34   fuse_skip_layer_normalization,
34 35   fuse_skip_rms_normalization,
@@ -104,6 +105,7 @@ def fuse(func, **kwargs):
104 105   fusion_count["attention"] = fuse(fuse_attention)
105 106   fusion_count["gelu"] = fuse(fuse_gelu)
106 107   fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
    108 + fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha)
107 109   # Finally: inline any intermediate fusion functions introduced that were not
108 110   # consumed by other fusions, and eliminate any remaining unused nodes.
109 111   optimize(model)

0 commit comments

Comments
 (0)