Skip to content

Commit d44853e

Browse files
authored
Attention fusion (part 2) (#2013)
Continuation of attention fusion.
* Adds a version of GroupQueryAttention.
* Adds support in Cos-Sin cache fusion for constant-folded position-ids.
* Restructures MHA fusion into a class-based rewrite rule.
Also restructures the folder structure.
* Eventually eliminate folders called "onnxruntime" and "transformers", which hinder importing the original packages with those names. For now moving just the relevant new files. (Will restructure older files later.)
* ORT-specific fusions go into the ort_fusions folder.
1 parent 6d2b530 commit d44853e

23 files changed

+471
-251
lines changed

.lintrunner.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ exclude_patterns = [
5050
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
53-
'onnxscript/rewriter/onnxruntime/xformers/_smollm_*.py', # onnxscript code
53+
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
5454
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
5555
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5656
'onnxscript/tools/function_unittest_producer.py', # FIXME

onnxscript/rewriter/generic_pattern.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,10 @@ def match(
549549
model: ir.Model,
550550
graph_or_function: ir.Graph | ir.Function,
551551
node: ir.Node,
552+
*,
552553
verbose: int = 0,
553554
remove_nodes: bool = True,
555+
tracer: orp.MatchingTracer | None = None,
554556
) -> orp.MatchResult | None:
555557
if not remove_nodes:
556558
raise NotImplementedError(

onnxscript/rewriter/onnxruntime/xformers/__init__.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

onnxscript/rewriter/onnxruntime/xformers/mha.py

Lines changed: 0 additions & 178 deletions
This file was deleted.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fusion optimizations for ORT backend."""
4+
5+
__all__ = [
6+
"optimize_for_ort",
7+
]
8+
9+
from onnxscript.rewriter.ort_fusions._core import optimize_for_ort
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import onnxscript.ir as ir
6+
from onnxscript.optimizer import optimize, remove_unused_nodes
7+
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
8+
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
9+
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
10+
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
11+
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
12+
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
13+
14+
15+
def fuse_xformers(model: ir.Model) -> None:
16+
optimize(model)
17+
fuse_rms_normalization(model)
18+
fuse_normalization(model)
19+
fuse_rotary_embedding(model)
20+
fuse_cos_sin_cache(model)
21+
fuse_sdpa(model)
22+
fuse_mha(model)
23+
remove_unused_nodes(model)
24+
25+
26+
def optimize_for_ort(model: ir.Model) -> None:
27+
# TODO(rama): Include the other optimizations
28+
fuse_xformers(model)
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)