Skip to content

Commit d659391

Browse files
committed
Fix
1 parent f7b854c commit d659391

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,19 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
14441444
# SDPAParams gained an additional argument in PyTorch 2.5
14451445
args = []
14461446
assert k_and_v.both_in_parallel()
1447+
# This is also done in `MultiHeadSelfAttention.scaled_dot_product_attention`
1448+
if mask is None and enable_gqa:
1449+
# Some efficient kernels have not implemented
1450+
# `enable_gqa=True`. It is better to extend keys, values in
1451+
# this case.
1452+
key = k_and_v.keys()
1453+
value = k_and_v.values()
1454+
q_per_kv = config.n_head // config.n_query_groups
1455+
key = key.repeat_interleave(q_per_kv, dim=1)
1456+
value = value.repeat_interleave(q_per_kv, dim=1)
1457+
assert query.shape[1] == key.shape[1]
1458+
k_and_v = DefaultKeysAndValues(key, value)
1459+
14471460
if hasattr(SDPAParams, "enable_gqa"):
14481461
args.append(enable_gqa)
14491462
params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
@@ -1506,6 +1519,7 @@ def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores):
15061519
q_per_kv = config.n_head // config.n_query_groups
15071520
key = key.repeat_interleave(q_per_kv, dim=1)
15081521
value = value.repeat_interleave(q_per_kv, dim=1)
1522+
assert query.shape[1] == key.shape[1]
15091523
k_and_v = DefaultKeysAndValues(key, value)
15101524

15111525
if hasattr(SDPAParams, "enable_gqa"):

0 commit comments

Comments
 (0)