Skip to content

Commit f7b854c

Browse files
committed
Fix
1 parent 46c8d37 commit f7b854c

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/test_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import litgpt.config as config_module
3737
from litgpt import GPT, Config
38+
from litgpt.attention import DefaultKeysAndValues
3839
from litgpt.model import CausalSelfAttention
3940
from litgpt.scripts.convert_hf_checkpoint import (
4041
copy_weights_falcon,
@@ -1495,6 +1496,18 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
14951496
# SDPAParams gained an additional argument in PyTorch 2.5
14961497
args = []
14971498
assert k_and_v.both_in_parallel()
1499+
# This is also done in `MultiHeadSelfAttention.scaled_dot_product_attention`
1500+
if mask is None and enable_gqa:
1501+
# Some efficient kernels have not implemented
1502+
# `enable_gqa=True`. It is better to extend keys, values in
1503+
# this case.
1504+
key = k_and_v.keys()
1505+
value = k_and_v.values()
1506+
q_per_kv = config.n_head // config.n_query_groups
1507+
key = key.repeat_interleave(q_per_kv, dim=1)
1508+
value = value.repeat_interleave(q_per_kv, dim=1)
1509+
k_and_v = DefaultKeysAndValues(key, value)
1510+
14981511
if hasattr(SDPAParams, "enable_gqa"):
14991512
args.append(enable_gqa)
15001513
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)

0 commit comments

Comments
 (0)