33from __future__ import annotations
44
55import math
6- import onnx_ir as ir
76
87from onnxscript .rewriter import _fusion_utils , _ir_utils , pattern
98
1716
Example pattern:
    query -> Mul(scale) -> MultiHeadAttention -> output

Gets rewritten to:
    query -> MultiHeadAttention(with integrated scaling) -> output
2322"""
2423
24+
2525class FuseMHAScale (pattern .RewriteRuleClassBase ):
def pattern(self, op, query, scale):
    """Describe the subgraph to match: ``Mul(query, scale)`` feeding MultiHeadAttention.

    Args:
        op: Builder object whose attribute calls (``op.Mul``,
            ``op.MultiHeadAttention``) construct the pattern nodes.
        query: Pattern variable for the unscaled query input.
        scale: Pattern variable for the multiplier applied to ``query``.

    Returns:
        Whatever ``op.MultiHeadAttention`` returns for the
        ``com.microsoft`` node whose first input is the scaled query;
        the output is registered under the name ``"mha_output"``, which
        is the name ``rewrite`` receives it by.
    """
    # NOTE(review): `_allow_other_inputs=True` presumably lets the match
    # accept MHA nodes that carry further inputs after the query -- confirm
    # against the rewriter's pattern-matching documentation.
    return op.MultiHeadAttention(
        op.Mul(query, scale),
        _allow_other_inputs=True,
        _domain="com.microsoft",
        _outputs=["mha_output"],
    )
3135
3236 def check (self , context , scale , ** _ ):
33- scale_value = _ir_utils .get_singleton_value (scale )
37+ scale_value = _ir_utils .get_singleton_value (scale )
3438 if scale_value is None or not isinstance (scale_value , (int , float )):
3539 return pattern .MatchResult ().fail ("Scale must be a constant numeric value." , scale )
3640 self ._scale = scale_value
@@ -54,8 +58,11 @@ def rewrite(self, op, query, mha_output, **_):
5458 inputs [0 ] = query
5559 attributes = dict (attributes )
5660 attributes ["scale" ] = self ._scale
57- return op .MultiHeadAttention (* inputs , ** attributes , _domain = "com.microsoft" , _outputs = 1 )
61+ return op .MultiHeadAttention (
62+ * inputs , ** attributes , _domain = "com.microsoft" , _outputs = 1
63+ )
64+
5865
# Rule set holding the single rule built from FuseMHAScale.
_mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()])

# Module-level entry point: the object returned by
# `_fusion_utils.apply_fusion_rules` for the rule set above.
fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
0 commit comments