
Commit c219dce

MHA fusion cleanup (#2481)
* Cleanup MHA fusion rules by eliminating some redundant rule-variations
* Fix handling of scale attribute in MHA fusion
* Introduce can_match_none attribute-variables in patterns
* Introduce fusion rule for MHA and scale

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent b042f5b commit c219dce

File tree

8 files changed: +110 -51 lines changed


onnxscript/rewriter/_pattern_ir.py

Lines changed: 20 additions & 6 deletions
@@ -76,20 +76,33 @@ def __str__(self) -> str:
 class AttrPattern(Pattern[ir.Attr]):
     """Base class for an attribute pattern. Matches any attribute value by default."""

-    def __init__(self, name: str | None):
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
         self._name = name
+        self._can_match_none = can_match_none

     @property
     def name(self) -> str | None:
         return self._name

+    @property
+    def can_match_none(self) -> bool:
+        """Indicates whether this pattern can match a None attribute."""
+        return self._can_match_none
+
     def matches(self, attr: ir.Attr) -> bool:
         return True

     def __str__(self) -> str:
         return self._name if self._name is not None else "anonymous:" + str(id(self))


+class AttrVar(AttrPattern):
+    """Represents a pattern variable used to match against attribute values."""
+
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
+        super().__init__(name, can_match_none=can_match_none)
+
+
 # TODO: Support tensors. Align with usage elsewhere.
 SupportedAttrTypes = Union[
     int,
@@ -129,11 +142,11 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
         # annotations to distinguish between ValuePattern and AttrPattern, but forces users to
         # use these type annotations.
         # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
-        if value.can_match_none or value.check_method is not None:
+        if value.check_method is not None:
             raise ValueError(
-                "Pattern variables used in attributes must not have can_match_none or check_method set."
+                "Pattern variables used in attributes must not have check_method set."
             )
-        return AttrPattern(value.name)
+        return AttrVar(value.name, can_match_none=value.can_match_none)
     if isinstance(value, (int, float, str)):
         return AttrConstantPattern(value)
     if isinstance(value, Sequence):
@@ -493,8 +506,9 @@ def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchRes
         for name, attr_pattern in self.attributes.items():
             attr_value = node.attributes.get(name)
             if attr_value is None:
-                return match.fail(f"Attribute {name} not found in node.", node)
-            if not attr_pattern.matches(attr_value):
+                if not attr_pattern.can_match_none:
+                    return match.fail(f"Attribute {name} not found in node.", node)
+            elif not attr_pattern.matches(attr_value):
                 return match.fail(
                     f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.",
                     node,
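For rule authors, the practical effect of can_match_none is that an attribute variable can bind even when the attribute is absent on the matched node, in which case it binds to None. A minimal sketch of the intended usage, modeled on the mha_bias.py change later in this commit; the CustomOp name and the rule itself are illustrative, not part of the diff:

from onnxscript.rewriter import pattern


class ForwardOptionalScale(pattern.RewriteRuleClassBase):
    def pattern(self, op, x):
        # Binds "scale" to the attribute value if present, or to None when the attribute is absent.
        return op.CustomOp(
            x,
            scale=pattern.AttrVar("scale", can_match_none=True),
            _domain="custom.domain",
        )

    def rewrite(self, op, x, scale, **_):
        # scale is None when the matched node carried no scale attribute; forwarding it as-is
        # mirrors how mha_bias.py propagates an optional attribute to the new node.
        return op.CustomOp(x, scale=scale, _domain="custom.domain")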

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.mha_bias import fuse_mha_bias
+from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -82,6 +83,7 @@ def fuse(func, **kwargs):
     fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
     fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
     fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    common_passes.CommonSubexpressionEliminationPass()(model)
     fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)

     # We apply shape inference after the SDPA fusion as new nodes are added
@@ -90,9 +92,9 @@ def fuse(func, **kwargs):

     fusion_count["gqa"] = fuse(fuse_gqa)
     fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
-
     fusion_count["mha1"] = fuse(fuse_mha1)
     fusion_count["mha2"] = fuse(fuse_mha2)
+    fusion_count["mha_scale"] = fuse(fuse_mha_scale)
     if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         fusion_count["mha_bias"] = 0
         fusion_count["attention"] = 0
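Ordering matters here: fuse_mha_scale can only fire after fuse_mha1/fuse_mha2 have produced com.microsoft MultiHeadAttention nodes, and the bias/attention fusions build on those in turn. A rough sketch of driving the same steps by hand, mirroring the updated whisper-encoder test below; it assumes `model` is an already-loaded ONNX IR model with the earlier SDPA-level fusions applied:

from onnxscript.rewriter.ort_fusions import _core as xformers

# model: ir.Model prepared elsewhere (SDPA and rotary fusions already applied).
mha_count = xformers.fuse_mha1(model)
mha_count += xformers.fuse_mha2(model)
if mha_count > 0:
    # Fold a preceding Mul on the query into MHA's scale attribute, then fold biases.
    xformers.fuse_mha_scale(model)
    xformers.fuse_mha_bias(model)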

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 7 additions & 4 deletions
@@ -111,7 +111,7 @@ def pattern(
                 num_heads=num_heads,
                 # scale=scale,
                 _domain="com.microsoft",
-                _outputs=3,
+                _outputs=["mha_output", "present_key", "present_value"],
             )
             # Concat present_key and present_value to form present
             present_key = op.Unsqueeze(present_key, [0])
@@ -132,7 +132,7 @@ def pattern(
                 num_heads=num_heads,
                 # scale=scale,
                 _domain="com.microsoft",
-                _outputs=1,
+                _outputs=["mha_output"],
             )
         return attention

@@ -260,6 +260,7 @@ def rewrite(
         attention_bias,
         num_heads,
         # scale,
+        mha_output,
         q_mul=None,
         k_mul=None,
         v_mul=None,
@@ -274,6 +275,8 @@ def rewrite(
         if self._no_slice:
            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1)

+        scale = mha_output.producer().attributes.get_float("scale", None)
+
         if self._has_past:
             attention, present = op.Attention(
                 input,
@@ -285,7 +288,7 @@ def rewrite(
                 # past_sequence_length
                 num_heads=num_heads,
                 qkv_hidden_sizes=qkv_hidden_sizes,
-                # scale=scale,
+                scale=scale,
                 _domain="com.microsoft",
                 _outputs=2,
             )
@@ -302,7 +305,7 @@ def rewrite(
                 None,  # past_sequence_length
                 num_heads=num_heads,
                 qkv_hidden_sizes=qkv_hidden_sizes,
-                # scale=scale,
+                scale=scale,
                 _domain="com.microsoft",
                 _outputs=1,
             )
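The switch from _outputs=3 to named outputs is what makes the scale forwarding possible: naming a pattern output binds the corresponding ir.Value, so the rewrite can walk back to its producer and read attributes of the matched MultiHeadAttention node. Roughly, as an illustrative excerpt rather than the full rewrite body:

# mha_output is the ir.Value bound to the named output of the matched pattern.
mha_node = mha_output.producer()                      # the matched com.microsoft MultiHeadAttention node
scale = mha_node.attributes.get_float("scale", None)  # None if the node has no explicit scale attribute
# ... then forwarded as scale=scale on the generated com.microsoft Attention op.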

onnxscript/rewriter/ort_fusions/attention_test.py

Lines changed: 2 additions & 0 deletions
@@ -176,6 +176,8 @@ def test_whisper_encoder(self):
         mha_count = xformers.fuse_mha1(model)
         mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
+        mha_scale_count = xformers.fuse_mha_scale(model)
+        self.assertGreater(mha_scale_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)
         # TODO: Enable once source of discrepancy is found

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 5 additions & 36 deletions
@@ -38,16 +38,12 @@ def __init__(
         name,
         *,
         double_transpose: bool,
-        transpose_4d: bool,
-        pre_scale_q: bool,
         is_rotary: bool,
         has_past_present: bool,
         is_cross_attention: bool,
     ):
         super().__init__(name)
         self._double_transpose = double_transpose
-        self._transpose_4d = transpose_4d
-        self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
         self._has_past_present = has_past_present
         self._is_cross_attention = is_cross_attention
@@ -63,12 +59,9 @@ def pattern(
         position_ids,
         cos,
         sin,
-        q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)

-        if self._pre_scale_q:
-            query_BSD = op.Mul(query_BSD, q_scale)
         # Reshape from (B, S, D) to (B, S, H, D/H)
         query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
@@ -93,24 +86,12 @@ def pattern(
             value_BHSDh = value

         if self._is_rotary:
-            # This is workaround for examples where there is a duplication of Unsqueeze op
-            # to generate a 2D positions-ids from a 1D position-ids. This can be eliminated
-            # if we have CSE-optimization to eliminate the duplicate Unsqueeze ops.
-            # For now, same flag (transpose_4d) controls this variation. A different flag
-            # can be added if we see instances that mix the two.
-            if self._transpose_4d:
-                position_ids_q = op.Unsqueeze(position_ids, [0])
-                position_ids_k = op.Unsqueeze(position_ids, [0])
-            else:
-                position_ids_q = position_ids
-                position_ids_k = position_ids
-
             query_BHSDh_emb = op.RotaryEmbedding(
-                query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
+                query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
             )
             if not self._is_cross_attention:
                 key_BHSDh_emb = op.RotaryEmbedding(
-                    key, position_ids_k, cos, sin, _domain="com.microsoft"
+                    key, position_ids, cos, sin, _domain="com.microsoft"
                 )
             else:
                 key_BHSDh_emb = key
@@ -289,6 +270,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         else:
             self._use_mask_broadcast = False

+        self._scale = sdpa_node.attributes.get_float("scale", None)
         # TODO: verify Reshapes:
         # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
         # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
@@ -307,20 +289,14 @@ def rewrite(
         position_ids,
         cos,
         sin,
-        q_scale=None,
         **_,
     ):
-        scale = _ir_utils.get_singleton_value(q_scale)
         num_heads = _ir_utils.get_dim(query_BSHDh, 2)
         if not isinstance(num_heads, int):
             return None

         # TODO: forward other attributes

-        if self._transpose_4d:
-            zero_1d = op.Constant(value_ints=[0])
-            position_ids = op.Unsqueeze(position_ids, zero_1d)
-
         if self._is_rotary:
             query_BSD_emb = op.RotaryEmbedding(
                 query_BSD, position_ids, cos, sin, _domain="com.microsoft"
@@ -360,27 +336,21 @@ def rewrite(
             past_key,
             past_value,
             num_heads=num_heads,
-            scale=scale,
             _domain="com.microsoft",
             _outputs=num_outputs,
+            scale=self._scale,
         )


 def _make_rule_set(has_past_present: bool):
     parameter_combinations = [
         {
             "double_transpose": double_transpose,
-            "transpose_4d": transpose_4d,
-            "pre_scale_q": pre_scale_q,
             "is_rotary": is_rotary,
             "has_past_present": has_past_present,
             "is_cross_attention": is_cross_attention,
         }
         for double_transpose in [False, True]
-        for transpose_4d in (
-            [False, True] if double_transpose else [False]
-        )  # Only generate patterns when double_transpose is True
-        for pre_scale_q in [True, False]
         for is_rotary in [False, True]
         for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
@@ -389,9 +359,8 @@ def _make_rule_set(has_past_present: bool):
     mha_rules = pattern.RewriteRuleSet(
         [
             MultiHeadAttention.rule(
-                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"MHA"
                 f"{'_Twice' if params['double_transpose'] else ''}"
-                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
                 f"{'_Rotary' if params['is_rotary'] else ''}"
                 f"{'_Past' if params['has_past_present'] else ''}"
                 f"{'_CrossAttention' if params['is_cross_attention'] else ''}",

onnxscript/rewriter/ort_fusions/mha_bias.py

Lines changed: 3 additions & 4 deletions
@@ -28,7 +28,6 @@ def pattern(
         past_key,
         past_value,
         num_heads,
-        # scale,
     ):
         query_BSD = pattern.OrValue(
             [op.Add(query_matmul, q_bias), query_matmul],
@@ -56,7 +55,7 @@ def pattern(
             pattern.Var("past_key", can_match_none=True),
             pattern.Var("past_value", can_match_none=True),
             num_heads=num_heads,
-            # scale=scale,
+            scale=pattern.AttrVar("scale", can_match_none=True),
             _domain="com.microsoft",
         )

@@ -132,7 +131,7 @@ def rewrite(
         past_key,
         past_value,
         num_heads,
-        # scale,
+        scale,
         **_,
     ):
         if q_bias is None:
@@ -158,7 +157,7 @@ def rewrite(
             past_key,
             past_value,
             num_heads=num_heads,
-            # scale=scale,
+            scale=scale,
             _domain="com.microsoft",
         )

onnxscript/rewriter/ort_fusions/mha_scale.py (new file)

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import math
+
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+
+"""
+Multi-Head Attention (MHA) pre-scaling fusion patterns.
+
+This module contains rewrite rules for fusing scale operations that occur before
+Multi-Head Attention operations. The fusion optimizes patterns where a query tensor
+is scaled before being passed to MHA by incorporating the scaling directly into
+the MHA operation.
+
+Example pattern:
+    query -> Mul(scale) -> MultiHeadAttention -> output
+
+Gets rewritten to:
+    query -> MultiHeadAttention(with integrated scaling) -> output
+"""
+
+
+class FuseMHAScale(pattern.RewriteRuleClassBase):
+    def pattern(self, op, query, scale):
+        scaled_query = op.Mul(query, scale)
+        mha_output = op.MultiHeadAttention(
+            scaled_query,
+            _allow_other_inputs=True,
+            _domain="com.microsoft",
+            _outputs=["mha_output"],
+        )
+        return mha_output
+
+    def check(self, context, scale, **_):
+        scale_value = _ir_utils.get_singleton_value(scale)
+        if scale_value is None or not isinstance(scale_value, (int, float)):
+            return pattern.MatchResult().fail("Scale must be a constant numeric value.", scale)
+        self._scale = scale_value
+        return True
+
+    def rewrite(self, op, query, mha_output, **_):
+        # Integrate the scale into the MHA operation
+        mha_node = mha_output.producer()
+        assert mha_node is not None
+        # Compute original scale factor for MHA:
+        attributes = mha_node.attributes
+        original_scale = attributes.get_float("scale", None)
+        if original_scale is None:
+            num_heads = attributes.get_int("num_heads", None)
+            if num_heads is None:
+                return None
+            head_size = query.shape[-1] // num_heads
+            original_scale = 1.0 / math.sqrt(head_size)
+        self._scale *= original_scale
+        inputs = list(mha_node.inputs)
+        inputs[0] = query
+        attributes = dict(attributes)
+        attributes["scale"] = self._scale
+        return op.MultiHeadAttention(
+            *inputs, **attributes, _domain="com.microsoft", _outputs=1
+        )
+
+
+_mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()])
+
+fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
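To make the scale arithmetic concrete: when the matched MultiHeadAttention node has no explicit scale attribute, the rule assumes the operator's default of 1/sqrt(head_size) and folds the constant from the Mul into it. A worked example under assumed dimensions (hidden size 768, 12 heads, pre-scale 0.125), none of which come from this diff:

import math

pre_scale = 0.125                     # constant matched from the Mul on the query
hidden_size, num_heads = 768, 12      # assumed model dimensions
head_size = hidden_size // num_heads  # 64
default_scale = 1.0 / math.sqrt(head_size)  # 0.125, used when MHA carries no scale attribute
fused_scale = pre_scale * default_scale     # 0.015625, written back as MHA's scale attribute
print(fused_scale)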

onnxscript/rewriter/pattern.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from onnxscript.rewriter._matcher import PatternMatcher, SimplePatternMatcher
 from onnxscript.rewriter._pattern_ir import (
     ANY_VALUE,
+    AttrVar,
     Constant,
     OpsetPatternBuilder,
     OrValue,
@@ -26,6 +27,7 @@

 __all__ = [
     "ANY_VALUE",
+    "AttrVar",
     "OrValue",
     "Constant",
     "OpsetPatternBuilder",
