File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -272,15 +272,24 @@ def scaled_dot_product_attention(
272
272
# in this case.
273
273
key = k_and_v .keys ()
274
274
value = k_and_v .values ()
275
+ is_causal = mask is None
276
+ enable_gqa = self .config .n_query_groups < self .config .n_head
277
+ if is_causal and enable_gqa :
278
+ # Some efficient kernels have not implemented
279
+ # `enable_gqa=True`. It is better to extend keys, values in
280
+ # this case.
281
+ q_per_kv = self .config .n_head // self .config .n_query_groups
282
+ key = key .repeat_interleave (q_per_kv , dim = 1 )
283
+ value = value .repeat_interleave (q_per_kv , dim = 1 )
275
284
kwargs = dict (
276
285
query = query ,
277
286
key = key ,
278
287
value = value ,
279
288
attn_mask = mask ,
280
289
dropout_p = 0.0 ,
281
290
scale = scale ,
282
- is_causal = mask is None ,
283
- enable_gqa = self . config . n_query_groups < self . config . n_head ,
291
+ is_causal = is_causal ,
292
+ enable_gqa = enable_gqa ,
284
293
)
285
294
self ._filter_sdpa_kernels (** kwargs )
286
295
if self ._sdpa_kernels is not None :
You can’t perform that action at this time.
0 commit comments