Commit 05d11f6

Address PR feedback
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 347307a commit 05d11f6

File tree

2 files changed: 21 additions & 9 deletions

onnxscript/rewriter/_pattern_ir.py

Lines changed: 7 additions & 2 deletions
@@ -96,7 +96,12 @@ def __str__(self) -> str:
         return self._name if self._name is not None else "anonymous:" + str(id(self))
 
 
-AttrVar = AttrPattern
+class AttrVar(AttrPattern):
+    """Represents a pattern variable used to match against attribute values."""
+
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
+        super().__init__(name, can_match_none=can_match_none)
+
 
 # TODO: Support tensors. Align with usage elsewhere.
 SupportedAttrTypes = Union[
@@ -141,7 +146,7 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
             raise ValueError(
                 "Pattern variables used in attributes must not have check_method set."
            )
-        return AttrPattern(value.name, can_match_none=value.can_match_none)
+        return AttrVar(value.name, can_match_none=value.can_match_none)
     if isinstance(value, (int, float, str)):
        return AttrConstantPattern(value)
     if isinstance(value, Sequence):
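
For context, a minimal, hedged sketch of the kind of rule this change supports; it is not part of the commit. It assumes the onnxscript.rewriter pattern API (pattern.RewriteRule taking a pattern builder and a replacement builder) and relies on the behavior shown in _to_attr_pattern above: a pattern variable used in attribute position is promoted to an AttrVar, and its binding can be reused in the replacement. The double-Cast rule, function names, and op choice are illustrative only.

# Illustrative sketch only (not part of this commit); op choice and names are examples.
from onnxscript.rewriter import pattern


def _double_cast(op, x, dtype):
    # "dtype" is used in attribute position ("to"), so it is promoted to an
    # attribute variable (AttrVar) and must bind to the same value in both Casts.
    return op.Cast(op.Cast(x, to=dtype), to=dtype)


def _single_cast(op, x, dtype):
    # Casting twice to the same target type is equivalent to casting once,
    # so the matched subgraph is replaced by a single Cast reusing the bound "to".
    return op.Cast(x, to=dtype)


double_cast_rule = pattern.RewriteRule(_double_cast, _single_cast)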

onnxscript/rewriter/ort_fusions/mha_scale.py

Lines changed: 14 additions & 7 deletions
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import math
-import onnx_ir as ir
 
 from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
 
@@ -17,20 +16,25 @@
 
 Example pattern:
     query -> Mul(scale) -> MultiHeadAttention -> output
-
+
 Gets rewritten to:
     query -> MultiHeadAttention(with integrated scaling) -> output
 """
 
+
 class FuseMHAScale(pattern.RewriteRuleClassBase):
     def pattern(self, op, query, scale):
         scaled_query = op.Mul(query, scale)
-        mha_output = op.MultiHeadAttention(scaled_query, _allow_other_inputs=True,
-            _domain="com.microsoft", _outputs=["mha_output"])
+        mha_output = op.MultiHeadAttention(
+            scaled_query,
+            _allow_other_inputs=True,
+            _domain="com.microsoft",
+            _outputs=["mha_output"],
+        )
         return mha_output
 
     def check(self, context, scale, **_):
-        scale_value =_ir_utils.get_singleton_value(scale)
+        scale_value = _ir_utils.get_singleton_value(scale)
         if scale_value is None or not isinstance(scale_value, (int, float)):
             return pattern.MatchResult().fail("Scale must be a constant numeric value.", scale)
         self._scale = scale_value
@@ -54,8 +58,11 @@ def rewrite(self, op, query, mha_output, **_):
         inputs[0] = query
         attributes = dict(attributes)
         attributes["scale"] = self._scale
-        return op.MultiHeadAttention(*inputs, **attributes, _domain="com.microsoft", _outputs=1)
+        return op.MultiHeadAttention(
+            *inputs, **attributes, _domain="com.microsoft", _outputs=1
+        )
+
 
 _mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()])
 
-fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
+fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
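
A hedged usage sketch for the exported fusion entry point follows; it is not part of the commit. It assumes the callable returned by _fusion_utils.apply_fusion_rules takes an onnx_ir model and returns the number of rewrites it applied, and that onnx_ir's load/save helpers are used for I/O. The file paths are placeholders.

# Illustrative sketch only (not part of this commit). Paths are placeholders, and the
# returned rewrite count is an assumption about apply_fusion_rules.
import onnx_ir as ir

from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale

model = ir.load("model_with_scaled_mha.onnx")  # hypothetical input model
count = fuse_mha_scale(model)  # folds Mul(query, scale) into MHA's "scale" attribute
print(f"Applied {count} MHA scale fusion(s)")
ir.save(model, "model_fused.onnx")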
