Commit 46c8d37

Fix
1 parent: 762706e

File tree

1 file changed: 11 additions, 2 deletions

litgpt/attention.py

Lines changed: 11 additions & 2 deletions
@@ -272,15 +272,24 @@ def scaled_dot_product_attention(
         # in this case.
         key = k_and_v.keys()
         value = k_and_v.values()
+        is_causal = mask is None
+        enable_gqa = self.config.n_query_groups < self.config.n_head
+        if is_causal and enable_gqa:
+            # Some efficient kernels have not implemented
+            # `enable_gqa=True`. It is better to extend keys and values
+            # in this case.
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            key = key.repeat_interleave(q_per_kv, dim=1)
+            value = value.repeat_interleave(q_per_kv, dim=1)
         kwargs = dict(
             query=query,
             key=key,
             value=value,
             attn_mask=mask,
             dropout_p=0.0,
             scale=scale,
-            is_causal=mask is None,
-            enable_gqa=self.config.n_query_groups < self.config.n_head,
+            is_causal=is_causal,
+            enable_gqa=enable_gqa,
         )
         self._filter_sdpa_kernels(**kwargs)
         if self._sdpa_kernels is not None:
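For context, here is a minimal standalone sketch of the fallback this commit adds: when no attention mask is given (so `is_causal=True`) and the model uses grouped-query attention, keys and values are expanded with `repeat_interleave` so the kernel sees plain multi-head attention. The tensor shapes, head counts, and the PyTorch >= 2.5 `enable_gqa` keyword are assumptions for illustration; this is not the litgpt implementation itself.

```python
import torch
import torch.nn.functional as F

# Assumed layout (a sketch, not litgpt's actual shapes):
#   query:      (B, n_head, T, head_dim)
#   key/value:  (B, n_query_groups, T, head_dim)
B, T, head_dim = 2, 8, 16
n_head, n_query_groups = 8, 2

query = torch.randn(B, n_head, T, head_dim)
key = torch.randn(B, n_query_groups, T, head_dim)
value = torch.randn(B, n_query_groups, T, head_dim)

is_causal = True                      # corresponds to `mask is None` in the diff
enable_gqa = n_query_groups < n_head  # grouped-query attention in use

if is_causal and enable_gqa:
    # Repeat each KV group so every query head gets its own key/value head;
    # the kernel then runs as ordinary multi-head attention. This assumes
    # query heads are grouped contiguously per KV group, as `dim=1`
    # repeat_interleave in the diff implies.
    q_per_kv = n_head // n_query_groups
    key = key.repeat_interleave(q_per_kv, dim=1)
    value = value.repeat_interleave(q_per_kv, dim=1)
    enable_gqa = False  # GQA handling is no longer needed after the expansion

out = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=is_causal,
    enable_gqa=enable_gqa,  # the `enable_gqa` kwarg exists in PyTorch >= 2.5
)
print(out.shape)  # torch.Size([2, 8, 8, 16])
```

Note that the diff itself still passes the computed `enable_gqa` after the expansion; once key and value carry as many heads as the query, that flag should be a no-op, so the sketch resets it to False only for clarity.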
