@@ -38,16 +38,12 @@ def __init__(
3838 name ,
3939 * ,
4040 double_transpose : bool ,
41- transpose_4d : bool ,
42- pre_scale_q : bool ,
4341 is_rotary : bool ,
4442 has_past_present : bool ,
4543 is_cross_attention : bool ,
4644 ):
4745 super ().__init__ (name )
4846 self ._double_transpose = double_transpose
49- self ._transpose_4d = transpose_4d
50- self ._pre_scale_q = pre_scale_q
5147 self ._is_rotary = is_rotary
5248 self ._has_past_present = has_past_present
5349 self ._is_cross_attention = is_cross_attention
@@ -63,12 +59,9 @@ def pattern(
6359 position_ids ,
6460 cos ,
6561 sin ,
66- q_scale ,
6762 ):
6863 # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
6964
70- if self ._pre_scale_q :
71- query_BSD = op .Mul (query_BSD , q_scale )
7265 # Reshape from (B, S, D) to (B, S, H, D/H)
7366 query_BSHDh = op .Reshape (query_BSD , pattern .ANY_VALUE , _outputs = ["query_BSHDh" ])
7467 # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
@@ -93,24 +86,12 @@ def pattern(
9386 value_BHSDh = value
9487
9588 if self ._is_rotary :
96- # This is workaround for examples where there is a duplication of Unsqueeze op
97- # to generate a 2D positions-ids from a 1D position-ids. This can be eliminated
98- # if we have CSE-optimization to eliminate the duplicate Unsqueeze ops.
99- # For now, same flag (transpose_4d) controls this variation. A different flag
100- # can be added if we see instances that mix the two.
101- if self ._transpose_4d :
102- position_ids_q = op .Unsqueeze (position_ids , [0 ])
103- position_ids_k = op .Unsqueeze (position_ids , [0 ])
104- else :
105- position_ids_q = position_ids
106- position_ids_k = position_ids
107-
10889 query_BHSDh_emb = op .RotaryEmbedding (
109- query_BHSDh , position_ids_q , cos , sin , _domain = "com.microsoft"
90+ query_BHSDh , position_ids , cos , sin , _domain = "com.microsoft"
11091 )
11192 if not self ._is_cross_attention :
11293 key_BHSDh_emb = op .RotaryEmbedding (
113- key , position_ids_k , cos , sin , _domain = "com.microsoft"
94+ key , position_ids , cos , sin , _domain = "com.microsoft"
11495 )
11596 else :
11697 key_BHSDh_emb = key
@@ -289,6 +270,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
289270 else :
290271 self ._use_mask_broadcast = False
291272
273+ self ._scale = sdpa_node .attributes .get_float ("scale" , None )
292274 # TODO: verify Reshapes:
293275 # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
294276 # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
@@ -307,20 +289,14 @@ def rewrite(
307289 position_ids ,
308290 cos ,
309291 sin ,
310- q_scale = None ,
311292 ** _ ,
312293 ):
313- scale = _ir_utils .get_singleton_value (q_scale )
314294 num_heads = _ir_utils .get_dim (query_BSHDh , 2 )
315295 if not isinstance (num_heads , int ):
316296 return None
317297
318298 # TODO: forward other attributes
319299
320- if self ._transpose_4d :
321- zero_1d = op .Constant (value_ints = [0 ])
322- position_ids = op .Unsqueeze (position_ids , zero_1d )
323-
324300 if self ._is_rotary :
325301 query_BSD_emb = op .RotaryEmbedding (
326302 query_BSD , position_ids , cos , sin , _domain = "com.microsoft"
@@ -360,27 +336,21 @@ def rewrite(
360336 past_key ,
361337 past_value ,
362338 num_heads = num_heads ,
363- scale = scale ,
364339 _domain = "com.microsoft" ,
365340 _outputs = num_outputs ,
341+ scale = self ._scale ,
366342 )
367343
368344
369345def _make_rule_set (has_past_present : bool ):
370346 parameter_combinations = [
371347 {
372348 "double_transpose" : double_transpose ,
373- "transpose_4d" : transpose_4d ,
374- "pre_scale_q" : pre_scale_q ,
375349 "is_rotary" : is_rotary ,
376350 "has_past_present" : has_past_present ,
377351 "is_cross_attention" : is_cross_attention ,
378352 }
379353 for double_transpose in [False , True ]
380- for transpose_4d in (
381- [False , True ] if double_transpose else [False ]
382- ) # Only generate patterns when double_transpose is True
383- for pre_scale_q in [True , False ]
384354 for is_rotary in [False , True ]
385355 for is_cross_attention in ([False ] if has_past_present else [False , True ])
386356 ]
@@ -389,9 +359,8 @@ def _make_rule_set(has_past_present: bool):
389359 mha_rules = pattern .RewriteRuleSet (
390360 [
391361 MultiHeadAttention .rule (
392- f"MHA_ { '4D' if params [ 'transpose_4d' ] else '3D' } _Transpose "
362+ f"MHA "
393363 f"{ '_Twice' if params ['double_transpose' ] else '' } "
394- f"{ '_PreScaleQ' if params ['pre_scale_q' ] else '' } "
395364 f"{ '_Rotary' if params ['is_rotary' ] else '' } "
396365 f"{ '_Past' if params ['has_past_present' ] else '' } "
397366 f"{ '_CrossAttention' if params ['is_cross_attention' ] else '' } " ,
0 commit comments