Skip to content

Commit 1a8dbd7

Browse files
authored
Add a couple of variants of patterns in ORT fusions (#2077)
Add a couple of variants of patterns in ORT fusions (motivated by Phi4)
1 parent 1695ff3 commit 1a8dbd7

File tree

4 files changed

+54
-17
lines changed

4 files changed

+54
-17
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def fuse_xformers(model: ir.Model) -> None:
3939

4040

4141
def optimize_for_ort(model: ir.Model) -> None:
42-
rewrite(model, ORT_PATTERN_REWRITE_RULES)
4342
fuse_xformers(model)
43+
rewrite(model, ORT_PATTERN_REWRITE_RULES)

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,15 @@ def rewrite(
136136
)
137137

138138

139-
_cast = CosSinCacheFusion.rule("CosSinCache", 2048, cast=True, const_freqs=True)
140-
_no_cast = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
141-
142-
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _no_cast])
139+
_cast_const_freqs = CosSinCacheFusion.rule(
140+
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
141+
)
142+
_cast = CosSinCacheFusion.rule(
143+
"CosSinCache_cast_no_const_freqs", 2048, cast=True, const_freqs=False
144+
)
145+
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
146+
147+
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _basic])
143148

144149
debug: bool = True
145150

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,30 @@
99

1010

1111
class SDPA(pattern.RewriteRuleClassBase):
12-
def __init__(self, name: str, *, use_mask: bool, pre_scale: bool):
12+
def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool):
1313
super().__init__(name=name)
1414
self._use_mask = use_mask
1515
self._pre_scale = pre_scale
16+
self._use_mul = use_mul
1617

1718
def pattern(
1819
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
1920
):
2021
if self._pre_scale:
2122
# Some implementations scale the query and key before computing the dot product
22-
query = op.Mul(query, query_scale)
23-
key_transposed = op.Mul(key_transposed, key_scale)
23+
if self._use_mul:
24+
query = op.Mul(query, query_scale)
25+
key_transposed = op.Mul(key_transposed, key_scale)
26+
else:
27+
query = op.Div(query, query_scale)
28+
key_transposed = op.Div(key_transposed, key_scale)
2429
attn_score = op.MatMul(query, key_transposed)
2530
if not self._pre_scale:
2631
# Some implementations scale the dot product.
27-
attn_score = op.Div(attn_score, qk_scale)
32+
if self._use_mul:
33+
attn_score = op.Mul(attn_score, qk_scale)
34+
else:
35+
attn_score = op.Div(attn_score, qk_scale)
2836
if self._use_mask:
2937
# Some implementations add a mask to the dot product.
3038
attn_score = op.Add(attn_score, mask)
@@ -42,16 +50,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
4250
if not isinstance(hidden_size, int):
4351
return False
4452
expected_scaling_factor = math.sqrt(hidden_size)
53+
if self._use_mul:
54+
expected_scaling_factor = 1.0 / expected_scaling_factor
4555

4656
if self._pre_scale:
47-
# Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(hidden_size))
48-
sqrt_scaling_factor = 1.0 / math.sqrt(expected_scaling_factor)
57+
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
58+
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
4959
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
5060
return False
5161
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
5262
return False
5363
else:
54-
# Check if qk_scale is a scalar == sqrt(hidden_size)
64+
# Check if qk_scale is a scalar == expected_scaling_factor)
5565
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
5666
return False
5767

@@ -63,13 +73,35 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
6373
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
6474

6575

66-
masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True)
67-
masked_post_div_sdpa_rule = SDPA.rule("masked_post_div_sdpa", use_mask=True, pre_scale=False)
76+
masked_pre_div_sdpa_rule = SDPA.rule(
77+
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=False
78+
)
79+
masked_pre_mul_sdpa_rule = SDPA.rule(
80+
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True
81+
)
82+
masked_post_div_sdpa_rule = SDPA.rule(
83+
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
84+
)
85+
masked_post_mul_sdpa_rule = SDPA.rule(
86+
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True
87+
)
6888

69-
sdpa_rules = pattern.RewriteRuleSet([masked_pre_mul_sdpa_rule, masked_post_div_sdpa_rule])
89+
sdpa_rules = pattern.RewriteRuleSet(
90+
[
91+
masked_pre_mul_sdpa_rule,
92+
masked_post_div_sdpa_rule,
93+
masked_post_mul_sdpa_rule,
94+
masked_pre_div_sdpa_rule,
95+
]
96+
)
97+
98+
debug: bool = True
7099

71100

72101
def fuse_sdpa(model: ir.Model) -> int:
73102
count = sdpa_rules.apply_to_model(model)
74-
print(f"SDPA count: {count}")
103+
if count == 0 and debug:
104+
sdpa_rules.apply_to_model(model, debug=True)
105+
else:
106+
print(f"SDPA count: {count}")
75107
return count

onnxscript/rewriter/pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1830,7 +1830,7 @@ def report(self) -> None:
18301830
print(f"Rule: {rule}")
18311831
print(f"Best score: {matches[0].score()}")
18321832
for match in matches:
1833-
print(f"Status: {match.status}")
1833+
print(f"Status: {match.status.name}")
18341834
if match.status == MatchStatus.NO_MATCH:
18351835
print("Graph matching failed: " + match.match_result.reason)
18361836
node = match.match_result._failure_node

0 commit comments

Comments
 (0)