 import onnx_ir as ir

 import onnxscript.rewriter._fusion_utils as _fusion_utils
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import _basics, _ir_utils, pattern

 """
 GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different
 Dim = Union[int, ir.SymbolicDim]


-def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
+def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool:
+    return value in model.graph.inputs and value.name == name
+
+
+def _causal_mask(
+    op,
+    input_ids,
+    past_kv_cache,
+    shape_B111,
+    min_val,
+    window_size,
+    dtype,
+):
+    """Defines a pattern for a pure causal mask, with optional sliding window support."""
     seq_len = op.Shape(input_ids, end=2, start=1)
     seq_len_0D = op.Squeeze(seq_len)

@@ -42,28 +55,93 @@ def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
     total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
     total_seq_len = op.Reshape(total_seq_len_0D, [-1])

-    # The Phi modeling code generates the following +1 as the target-length, which seems
-    # unnecessary in this context. But using it for pattern-matching against
-    # generated onnx model.
-    total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
-    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])
-
     current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
-    mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
-    min_float32 = float(np.finfo(np.float32).min)
-    mask_all_min = op.Expand(min_float32, mask_shape)
-    total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
+    mask_shape = op.Concat(seq_len, total_seq_len, axis=0)
+    mask_all_min_expand = op.Expand(min_val, mask_shape)
+    # The following Trilu is optional: not used in Phi models, but used in Llama.
+    mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1)
+    mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu])
+    total_range_as_row = op.Range(0, total_seq_len_0D, 1)
     current_range_as_column = op.Reshape(current_range, [-1, 1])
-    boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
-    float_0_1_mask = op.Cast(boolean_mask, to=1)
+
+    non_causal = op.Greater(total_range_as_row, current_range_as_column)
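+    # Illustrative example (values assumed, not from the source): with past_seq_len=3
+    # and seq_len=2, current_range is [3, 4] and total_range_as_row is [0, 1, 2, 3, 4],
+    # so non_causal is [[F, F, F, F, T],
+    #                   [F, F, F, F, F]]  (True marks positions to mask out).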
+
+    # sliding window support:
+    current_range_minus_window = op.Sub(current_range_as_column, window_size)
+    out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window)
+    non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window)
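+    # Illustrative example (values assumed, not from the source): with window_size=3,
+    # a query at position 4 is blocked from keys at positions <= 4 - 3 = 1, so it can
+    # attend only to positions 2..4 once combined with the causal constraint.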
+
+    boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window])
+
+    float_0_1_mask = op.Cast(boolean_mask, to=dtype)
     float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
-    mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
-    mask_B1ST_plus = op.Expand(mask_4d, shape_B111)
+    mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1])
+    mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111)
+
+    return mask_4d_B1ST
+
+
+class _CausalMaskPattern(pattern.PatternBase):
+    def pattern(
+        self,
+        op,
+        input_ids,
+        past_kv_cache,
+        shape_B111,
+        min_val,
+        window_size,
+        dtype1,
+        attn_mask_2d,
+        dtype2,
+    ):
+        causal_mask = _causal_mask(
+            op,
+            input_ids,
+            past_kv_cache,
+            shape_B111,
+            min_val,
+            window_size,
+            dtype1,
+        )
+
+        attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2])
+        attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2)
+
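+        # How the two masks combine (derived from the ops below): causal_mask holds
+        # 0 (attend) or min_val (block), while attn_mask_4d_cast holds 1 (real token)
+        # or 0 (padding). Their sum is 0 exactly where the causal mask allows attention
+        # but the token is padding; the Where below forces those positions to min_val.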
+        sum = op.Add(causal_mask, attn_mask_4d_cast)
+        sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT)
+        # The cast is optional, and may be absent if the sum is already in float32.
+        sum_fp32 = pattern.OrValue([sum_fp32, sum])
+        is_zero = op.Equal(sum_fp32, 0.0)
+        result = op.Where(is_zero, min_val, causal_mask)
+        return result
+
+    def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, window_size=None, **_):
+        # Check that attn_mask_2d is the model input "attention_mask".
+        if not _is_model_input(attn_mask_2d, "attention_mask", context.model):
+            return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d)
+
+        if dtype1.as_int() != dtype2.as_int():
+            return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2])
+
+        # Check that min_val is a constant and matches the expected minimum value for the dtype.
+        min_value = _ir_utils.get_singleton_value(min_val)
+        if min_value is None:
+            return pattern.MatchResult().fail("min_val is not a constant.", min_val)
+        expected_min_value = np.finfo(min_val.dtype.numpy()).min
+        if min_value != expected_min_value:
+            return pattern.MatchResult().fail(
+                f"Expected min value {expected_min_value}, got {min_value}", min_val
+            )
+
+        # TODO(rama): Sliding window not yet supported. The pattern variable
+        # window_size is bound only when the sliding-window branch of the pattern matched.
+        if window_size is not None:
+            return pattern.MatchResult().fail(
+                "Sliding window not yet supported", window_size
+            )
+        return True

-    # Get rid of the extra +1 added above: total_seq_len is enough, no
-    # need for total_seq_len+1.
-    mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])
-    return mask_B1ST
+
+_causal_mask_pattern = _CausalMaskPattern()
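+# The singleton instance above is matched as a sub-pattern from
+# GroupQueryAttention.check below, rather than registered as a rewrite rule itself.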


 class GroupQueryAttention(pattern.RewriteRuleClassBase):
@@ -78,8 +156,7 @@ def pattern(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
@@ -101,15 +178,15 @@ def pattern(

         query_BHSDh_rope = op.RotaryEmbedding(
             query_BHSDh,
-            position_ids_q,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
             _outputs=["query_BHSDh_rope"],
         )
         key_BHkvSDh_rope = op.RotaryEmbedding(
             key_BHkvSDh,
-            position_ids_k,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
@@ -154,7 +231,7 @@ def pattern(

     def check(
         self,
-        op,
+        context: _basics.MatchContext,
         query_BSD,
         key_BSDkv,
         value_BSDkv,
@@ -164,6 +241,7 @@ def check(
         key_BHkvSDh_rope,
         query_BSHDh,
         key_BSHkvDh,
+        mask,
         **_,
     ):
         bindings: dict[str, Dim] = {}
@@ -210,6 +288,20 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         )
         self._interleaved = query_interleaved

+        # Check mask:
+        mask_node = mask.producer()
+        if mask_node is None:
+            return pattern.MatchResult().fail("Unhandled mask pattern", mask)
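+        # check_nodes_are_removable=False (as understood from its use here): the mask
+        # subgraph is only being recognized, not removed by this rule, so its nodes
+        # may legitimately have other consumers.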
+        mask_match_result = _causal_mask_pattern.match(
+            context.model,
+            context.graph_or_function,
+            mask_node,
+            check_nodes_are_removable=False,
+        )
+        if mask_match_result is None:
+            return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask)
+        # TODO: handle sliding window support in mask
+
         return True

     def rewrite(
@@ -220,104 +312,51 @@ def rewrite(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
         **_,
     ):
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
+        # Note that the following optimization is specific to current ORT GenAI
+        # attention-mask usage. Specifically, it assumes that the model input
+        # "attention_mask" is a 2D 0/1 mask with shape [batch_size, sequence_length],
+        # used only to indicate the current tokens. Hence, the input attention_mask
+        # is redundant as long as the past-sequence-length and current-sequence-length
+        # can be computed.
+
+        # Construct seqlens_k and total_seq_length_int32 from position_ids:
+        #   seqlens_k: int32[batch_size], holds total_sequence_length - 1 for each batch
+        #   position_ids: int64[batch_size, sequence_length], the position of each token
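+        # Illustrative example (values assumed, not from the source): with a past
+        # sequence of length 3 and two new tokens, position_ids could be [[3, 4]];
+        # ReduceMax over axis 1 then gives seqlens_k = [4] (= total_sequence_length - 1)
+        # and total_seq_length_int32 = 4 + 1 = 5.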
+        one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32))
+        one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64))
+        zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64))
+        seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0)
+        seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32)
+        max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
+        total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
+        return op.GroupQueryAttention(
             query_BSD,
             key_BSDkv,
             value_BSDkv,
             past_key,
             past_value,
-            None,  # seqlens_k,
-            None,  # total_seq_length_int32,
+            seqlens_k,
+            total_seq_length_int32,
             cos,
             sin,
             num_heads=self.num_heads,
             kv_num_heads=self.kv_num_heads,
             do_rotary=1,
             rotary_interleaved=self._interleaved,
             # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
-            _domain="ai.onnxruntime._fusion",
+            _domain="com.microsoft",
             _outputs=3,
         )


-class GQACausalMask(pattern.RewriteRuleClassBase):
-    def __init__(self):
-        super().__init__("GQACausalMask", remove_nodes=False)
-
-    def pattern(
-        self,
-        op,
-        mask,
-        input_ids,
-        some_kv_cache,
-        shape_B111,
-        past_seq_length,
-        total_seq_length,
-    ):
-        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
-        position_ids = op.Range(past_seq_length, total_seq_length, 1)
-        position_ids_q = op.Unsqueeze(position_ids, [0])
-        position_ids_k = op.Unsqueeze(position_ids, [0])
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
-            _allow_other_inputs=True,
-            _domain="ai.onnxruntime._fusion",
-            _outputs=["attn_output", "key_seq", "value_seq"],
-        )
-
-    def rewrite(
-        self,
-        op,
-        total_seq_length,
-        attn_output,
-        **_,
-    ):
-        # Construct total_seq_length_int32 and seqlens_k
-        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
-
-        gqa_node = attn_output.producer()
-        assert len(gqa_node.inputs) == 12, (
-            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
-        )
-        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
-        cos, sin = gqa_node.inputs[10:12]
-        updated_inputs = [
-            query,
-            key,
-            value,
-            past_key,
-            past_value,
-            seqlens_k,
-            total_seq_length_int32,
-            cos,
-            sin,
-        ]
-        attributes = gqa_node.attributes
-        return op.GroupQueryAttention(
-            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
-        )
-
-
 _basic_gqa_rule = GroupQueryAttention.rule()
-_gqa_causal_mask_rule = GQACausalMask.rule()

-gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

 fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)