99
1010
1111class SDPA (pattern .RewriteRuleClassBase ):
12- def __init__ (self , name : str , * , use_mask : bool , pre_scale : bool ):
12+ def __init__ (self , name : str , * , use_mask : bool , pre_scale : bool , use_mul : bool ):
1313 super ().__init__ (name = name )
1414 self ._use_mask = use_mask
1515 self ._pre_scale = pre_scale
16+ self ._use_mul = use_mul
1617
1718 def pattern (
1819 self , op , query , key_transposed , value , mask , query_scale , key_scale , qk_scale
1920 ):
2021 if self ._pre_scale :
2122 # Some implementations scale the query and key before computing the dot product
22- query = op .Mul (query , query_scale )
23- key_transposed = op .Mul (key_transposed , key_scale )
23+ if self ._use_mul :
24+ query = op .Mul (query , query_scale )
25+ key_transposed = op .Mul (key_transposed , key_scale )
26+ else :
27+ query = op .Div (query , query_scale )
28+ key_transposed = op .Div (key_transposed , key_scale )
2429 attn_score = op .MatMul (query , key_transposed )
2530 if not self ._pre_scale :
2631 # Some implementations scale the dot product.
27- attn_score = op .Div (attn_score , qk_scale )
32+ if self ._use_mul :
33+ attn_score = op .Mul (attn_score , qk_scale )
34+ else :
35+ attn_score = op .Div (attn_score , qk_scale )
2836 if self ._use_mask :
2937 # Some implementations add a mask to the dot product.
3038 attn_score = op .Add (attn_score , mask )
@@ -42,16 +50,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
4250 if not isinstance (hidden_size , int ):
4351 return False
4452 expected_scaling_factor = math .sqrt (hidden_size )
53+ if self ._use_mul :
54+ expected_scaling_factor = 1.0 / expected_scaling_factor
4555
4656 if self ._pre_scale :
47- # Check if query_scale and key_scale are scalars == 1/ sqrt(sqrt(hidden_size) )
48- sqrt_scaling_factor = 1.0 / math .sqrt (expected_scaling_factor )
57+ # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor )
58+ sqrt_scaling_factor = math .sqrt (expected_scaling_factor )
4959 if not _ir_utils .is_singleton_value (query_scale , sqrt_scaling_factor , rtol = 1e-3 ):
5060 return False
5161 if not _ir_utils .is_singleton_value (key_scale , sqrt_scaling_factor , rtol = 1e-3 ):
5262 return False
5363 else :
54- # Check if qk_scale is a scalar == sqrt(hidden_size )
64+ # Check if qk_scale is a scalar == expected_scaling_factor )
5565 if not _ir_utils .is_singleton_value (qk_scale , expected_scaling_factor , rtol = 1e-3 ):
5666 return False
5767
@@ -63,13 +73,35 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
6373 return op .SDPA (query , key_transposed , value , mask , _domain = "ai.onnxruntime.fusion" )
6474
6575
66- masked_pre_mul_sdpa_rule = SDPA .rule ("masked_pre_mul_sdpa" , use_mask = True , pre_scale = True )
67- masked_post_div_sdpa_rule = SDPA .rule ("masked_post_div_sdpa" , use_mask = True , pre_scale = False )
76+ masked_pre_div_sdpa_rule = SDPA .rule (
77+ "masked_pre_mul_sdpa" , use_mask = True , pre_scale = True , use_mul = False
78+ )
79+ masked_pre_mul_sdpa_rule = SDPA .rule (
80+ "masked_pre_mul_sdpa" , use_mask = True , pre_scale = True , use_mul = True
81+ )
82+ masked_post_div_sdpa_rule = SDPA .rule (
83+ "masked_post_div_sdpa" , use_mask = True , pre_scale = False , use_mul = False
84+ )
85+ masked_post_mul_sdpa_rule = SDPA .rule (
86+ "masked_post_div_sdpa" , use_mask = True , pre_scale = False , use_mul = True
87+ )
6888
69- sdpa_rules = pattern .RewriteRuleSet ([masked_pre_mul_sdpa_rule , masked_post_div_sdpa_rule ])
89+ sdpa_rules = pattern .RewriteRuleSet (
90+ [
91+ masked_pre_mul_sdpa_rule ,
92+ masked_post_div_sdpa_rule ,
93+ masked_post_mul_sdpa_rule ,
94+ masked_pre_div_sdpa_rule ,
95+ ]
96+ )
97+
98+ debug : bool = True
7099
71100
72101def fuse_sdpa (model : ir .Model ) -> int :
73102 count = sdpa_rules .apply_to_model (model )
74- print (f"SDPA count: { count } " )
103+ if count == 0 and debug :
104+ sdpa_rules .apply_to_model (model , debug = True )
105+ else :
106+ print (f"SDPA count: { count } " )
75107 return count
0 commit comments