Skip to content

Commit d2e9e45

Browse files
committed
Set enable_gqa flag in scaled_dot_product_attention
1 parent 7953793 commit d2e9e45

File tree

1 file changed

+10
-24
lines changed

1 file changed

+10
-24
lines changed

litgpt/model.py

Lines changed: 10 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -677,30 +677,16 @@ def scaled_dot_product_attention(
677677
# in this case.
678678
key = k_and_v.keys()
679679
value = k_and_v.values()
680-
for retry in range(2):
681-
try:
682-
y = F.scaled_dot_product_attention(
683-
query=q,
684-
key=key,
685-
value=value,
686-
attn_mask=mask,
687-
dropout_p=0.0,
688-
scale=scale,
689-
is_causal=is_causal,
690-
)
691-
break
692-
except RuntimeError as ex:
693-
if retry == 1 or self.config.n_query_groups == self.config.n_head:
694-
raise ex # Re-throw
695-
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
696-
# `scaled_dot_product_attention` is supposed to support
697-
# `query.shape = (bs, nh_q, ...), key.shape = (bs, nh_k, ...)`
698-
# and `nh_k < nh_q` if `nh_q` is a multiple of `nh_k`. But
699-
# this seems not yet supported (in 2.5.1), so have to lift
700-
# K, V here. This is annoying, as it wastes memory.
701-
q_per_kv = self.config.n_head // self.config.n_query_groups
702-
key = key.repeat_interleave(q_per_kv, dim=1)
703-
value = value.repeat_interleave(q_per_kv, dim=1)
680+
y = F.scaled_dot_product_attention(
681+
query=q,
682+
key=key,
683+
value=value,
684+
attn_mask=mask,
685+
dropout_p=0.0,
686+
scale=scale,
687+
is_causal=is_causal,
688+
enable_gqa=self.config.n_query_groups < self.config.n_head,
689+
)
704690
scores = None
705691
return y.transpose(1, 2), scores
706692

0 commit comments

Comments (0)