# Licensed under the MIT License.
from __future__ import annotations

-from typing import Sequence
+from typing import Sequence, Union

import onnxscript.ir as ir
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _ir_utils, pattern

"""
-The MultiHeadAttention pattern:
+The MultiHeadAttention pattern: generate an instance
+   MHA (query, key, value, None, None, mask, past_key, past_value)
+where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
+The next two inputs, bias and key_padding_mask, are None in this pattern. The mask (attention_bias)
+must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).

+We use the following abbreviations for the dimensions:
B: Batch size
S: Sequence length
D: input embedding dimension
+Dv: value hidden size (usually, Dv = D)
H: number of heads
-d_h: head size (usually, D = H * d_h)
+Dh: head size or embedding dimension per head (usually, D = H * Dh)
+Skv: key/value sequence length
+St: total sequence length

-thus, weights are usually of shape (D, D) and (D, D) and (D, D)
-
-for each of Q, K, and V, we have the following pattern:
-   MatMul (Input, W), producing output of shape (B, S, D)
-   Reshape to produce a matrix of shape (B, S, H, d_h)
-   Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)
-
-This is followed by a RotaryEmbedding pattern for Q and K
-
-The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)
-
-The dot-product attention is then computed using SDPA.
-Finally, the output is transposed and reshaped back to (B, S, D) shape
+In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
+The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
"""
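Editor's aside, not part of this commit: the suffix convention is easiest to see with concrete shapes. A minimal numpy sketch, with B, S, H, and Dh chosen arbitrarily:

```python
import numpy as np

B, S, H, Dh = 2, 4, 3, 5
D = H * Dh  # input embedding dimension, split across H heads

x_BSD = np.zeros((B, S, D))
x_BSHDh = x_BSD.reshape(B, S, H, Dh)       # suffix _BSHDh: shape (B, S, H, Dh)
x_BHSDh = x_BSHDh.transpose(0, 2, 1, 3)    # suffix _BHSDh: shape (B, H, S, Dh)
x_BH_S_Dh = x_BHSDh.reshape(B * H, S, Dh)  # suffix BH_S_Dh: shape (B*H, S, Dh)
assert x_BH_S_Dh.shape == (B * H, S, Dh)
```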

+Dim = Union[int, ir.SymbolicDim]

-def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:
+
+def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):
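The remainder of _check_shape's body is elided by the hunk boundary below. For orientation, here is a hedged sketch of the dimension-binding logic it is expected to implement; _bind_dims is a hypothetical stand-in, not the committed code:

```python
# Hedged sketch (assumption): bind each dimension name on first use,
# and require later occurrences of the same name to match that binding.
def _bind_dims(bindings: dict, dims: list, names: list) -> bool:
    for actual, name in zip(dims, names):
        if name in bindings:
            if bindings[name] != actual:  # inconsistent reuse of a name
                return False
        else:
            bindings[name] = actual  # first occurrence: record the binding
    return True
```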
@@ -46,131 +45,170 @@ def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str])


class MultiHeadAttention(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, use_2d_matmul: bool):
-        super().__init__(name)
-        self._use_2d_matmul = use_2d_matmul
-
-    def _compute_QKV(self, op, input, weight, reshape_var: str):
-        """Applied to generate each of Q, K, and V from input."""
-        if self._use_2d_matmul:
-            # Convert batched input of shape (B, S, D) to 2D input (B*S, D)
-            input = op.Reshape(input, _allow_other_inputs=True)
-        projected = op.MatMul(input, weight)
-        if self._use_2d_matmul:
-            # Convert 2D output back to batched output of shape (B, S, D)
-            projected = op.Reshape(projected, _allow_other_inputs=True)
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        reshaped = op.Reshape(
-            projected,
-            _allow_other_inputs=True,
-            _allow_other_attributes=True,
-            _outputs=[reshape_var],
-        )
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
-        return transposed
+    def __init__(self):
+        super().__init__("MHA")

    def pattern(
        self,
        op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
-        qkv_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
        position_ids,
+        cos,
+        sin,
    ):
-        query = self._compute_QKV(op, input, query_weight, "query_mm_reshaped")
-        key = self._compute_QKV(op, input, key_weight, "key_mm_reshaped")
-        value = self._compute_QKV(op, input, value_weight, "value_mm_reshaped")
+        # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        query_BSHDh = op.Reshape(
+            query_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["query_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        key_BSHDh = op.Reshape(
+            key_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["key_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        value_BSHDh = op.Reshape(
+            value_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["value_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+
+        query_BHSDh_rope = op.RotaryEmbedding(
+            query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BHSDh_rope = op.RotaryEmbedding(
+            key_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )

-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
+        # Concatenate past_key cache and current key, and transpose to enable
+        # dot-product attention computation.

-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.Concat(past_key, key_rope, axis=-2)
-        # Transpose last two axes of key_rope to compute dot-product via matmul.
-        key_reshaped = op.Reshape(
-            key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]
+        key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2)
+        # Transpose last two axes of key_seq to compute dot-product via matmul.
+        key_seq_BH_Skv_Dh = op.Reshape(
+            key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
        )
-        key_reshaped_transposed = op.Transpose(key_reshaped, perm=[0, 2, 1])
-        key_transposed = op.Reshape(
-            key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]
+        key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
+        key_seq_B_H_Dh_Skv = op.Reshape(
+            key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
        )

-        value = op.Concat(past_value, value, axis=-2)
+        # Concatenate past_value cache and current value
+        value_seq = op.Concat(past_value, value_BHSDh, axis=-2)

        attention = op.SDPA(
-            query_rope, key_transposed, value, mask, _domain="ai.onnxruntime.fusion"
+            query_BHSDh_rope,
+            key_seq_B_H_Dh_Skv,
+            value_seq,
+            mask,
+            _domain="ai.onnxruntime.fusion",
        )
-        # Transpose back to (B, S, H, D/H)
+
+        # Transpose attention back to (B, S, H, D/H)
        attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
        # Reshape back to (B, S, D)
        attention_reshaped = op.Reshape(
            attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]
        )
-        return attention_reshaped, key_rope, value
+        return attention_reshaped, key_seq, value_seq

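Editor's aside on the key_seq handling matched above: collapsing (B, H, Skv, Dh) to (B*H, Skv, Dh), transposing the last two axes, and splitting back to (B, H, Dh, Skv) is equivalent to swapping the last two axes of the 4-D tensor directly. A quick numpy check:

```python
import numpy as np

B, H, Skv, Dh = 2, 3, 7, 5
key = np.arange(B * H * Skv * Dh, dtype=np.float32).reshape(B, H, Skv, Dh)

# Reshape/Transpose/Reshape, as matched by the pattern above
via_3d = key.reshape(B * H, Skv, Dh).transpose(0, 2, 1).reshape(B, H, Dh, Skv)
# Direct swap of the last two axes
direct = key.transpose(0, 1, 3, 2)
assert np.array_equal(via_3d, direct)
```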
    def check(
        self,
        op,
-        query_mm_reshaped,
-        key_mm_reshaped,
-        value_mm_reshaped,
-        key_reshaped,
-        key_transposed,
-        attention_reshaped,
+        query_BSD,
+        key_BSD,
+        value_BSD,
+        mask,
+        past_key,
+        past_value,
+        query_BSHDh,
+        key_BSHDh,
+        value_BSHDh,
        **_,
    ):
-        bindings: dict[str, int] = {}
-        status = (
-            _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"])
-            and _check_shape(bindings, key_reshaped, ["B*H", "KVS", "d_h"])
-            and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "KVS"])
-            and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"])
-        )
-        if not status:
+        bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _check_shape(bindings, val, dims)
+
+        if no_match(query_BSD, ["B", "S", "D"]):
+            return False
+        if no_match(key_BSD, ["B", "Skv", "D"]):
+            return False
+        if no_match(value_BSD, ["B", "Skv", "D"]):
            return False
-        # if bindings["B"] * bindings["H"] != bindings["B*H"]:
-        #     return False
-        # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
-        #     return False
+
+        if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+            return False
+        if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            return False
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St).
+        # But this also, unfortunately, depends on the ORT version.
+
+        # TODO: verify Reshapes:
+        # e.g., verify bindings["B"] * bindings["H"] == bindings["B*H"]
+        # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"],
+        # or check the Reshape's shape-input value.
        return True
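The TODO above names two products to verify. A hedged sketch of what such a check could look like; it is hypothetical, and assumes the collapsed names "B*H" and "H*Dh" had been bound by earlier _check_shape calls and that all bindings are concrete ints:

```python
# Hypothetical helper for the TODO; not committed code.
def _reshape_dims_consistent(bindings: dict) -> bool:
    dims = [bindings.get(k) for k in ("B", "H", "Dh", "B*H", "H*Dh")]
    if not all(isinstance(d, int) for d in dims):
        return False  # symbolic dims: cannot verify numerically
    b, h, dh, bh, hdh = dims
    return b * h == bh and h * dh == hdh
```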

    def rewrite(
        self,
        op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
+        key_BSHDh,
        position_ids,
-        query_mm_reshaped,
+        cos,
+        sin,
        **_,
    ):
-        num_heads = query_mm_reshaped.shape[2]
-        query = op.MatMul(input, query_weight)
-        key = op.MatMul(input, key_weight)
-        value = op.MatMul(input, value_weight)
-
-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
+        num_heads = _ir_utils.get_dim(key_BSHDh, 2)
+        if not isinstance(num_heads, int):
+            return None
+
+        # Switch to 3D RotaryEmbedding
+        # TODO: forward other attributes
+        query_BSD_rope = op.RotaryEmbedding(
+            query_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BSD_rope = op.RotaryEmbedding(
+            key_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )

        return op.MultiHeadAttention(
-            query_rope,
-            key_rope,
-            value,
+            query_BSD_rope,
+            key_BSD_rope,
+            value_BSD,
            None,  # bias
            None,  # key padding mask
            mask,  # attention mask/bias
@@ -182,11 +220,15 @@ def rewrite(
        )


-_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=False)
+_rule1 = MultiHeadAttention.rule()

mha_rules = pattern.RewriteRuleSet([_rule1])


-def fuse_mha(model: ir.Model) -> int:
+def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:
    count = mha_rules.apply_to_model(model)
+    if debug and count == 0:
+        tracer = pattern.MatchingTracer()
+        mha_rules.apply_to_model(model, tracer=tracer)
+        tracer.report()
    return count
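A usage sketch (editor's illustration, with fuse_mha imported from this module; the model path is hypothetical): deserialize an ONNX model into the IR, then run the fusion, enabling debug to get a match-failure report when no rewrite fires:

```python
import onnx
import onnxscript.ir as ir

model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # hypothetical path
count = fuse_mha(model, debug=True)  # with debug, reports why matching failed
print(f"Fused {count} MultiHeadAttention pattern(s)")
```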