Commit 71f01bb

issue/942 - feat: add GQA python test for paged_attention_prefill
1 parent 180674d

File tree

1 file changed (+11, -0)


test/infiniop/paged_attention_prefill.py

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,9 @@
     (4, 16, 16, 128, 8, 64, 3),
     (8, 64, 64, 128, 8, 16, 5),
     (16, 128, 128, 128, 8, 16, 4),
+    (2, 8, 2, 128, 16, 32, 2),
+    (4, 16, 4, 128, 8, 64, 3),
+    (1, 64, 16, 128, 8, 16, 2),
 ]
 
 _TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
@@ -78,6 +81,10 @@ def ref_paged_attention_multi_turn(
     query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
 ):
     block_size = k_cache.shape[2]
+    num_heads = query_new.shape[1]
+    num_kv_heads = k_cache.shape[1]
+    num_queries_per_kv = num_heads // num_kv_heads
+
     outputs = torch.zeros_like(query_new)
     num_seqs = len(cum_seq_lens_q) - 1
     for i in range(num_seqs):
@@ -95,6 +102,10 @@ def ref_paged_attention_multi_turn(
 
         K = torch.stack(keys_all, dim=0)
         V = torch.stack(values_all, dim=0)
+
+        if num_queries_per_kv > 1:
+            K = torch.repeat_interleave(K, num_queries_per_kv, dim=1)
+            V = torch.repeat_interleave(V, num_queries_per_kv, dim=1)
         Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]
 
         scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
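The added reference logic supports grouped-query attention (GQA), where several query heads share one KV head: it expands K and V along the head dimension with repeat_interleave so they line up with the query heads before the einsum. Below is a minimal standalone sketch of that expansion; the concrete shapes, the scale value, and the softmax/output steps after the scores line are illustrative assumptions consistent with the einsum "qhd,khd->hqk" in the diff, not code taken verbatim from the test file.

import torch

# Illustrative GQA shapes: 8 query heads, 2 KV heads, head size 128.
num_heads, num_kv_heads, head_size = 8, 2, 128
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

Q = torch.randn(5, num_heads, head_size)     # (q_len, num_heads, head_size)
K = torch.randn(7, num_kv_heads, head_size)  # (kv_len, num_kv_heads, head_size)
V = torch.randn(7, num_kv_heads, head_size)

# repeat_interleave on dim=1 copies KV head j to query heads
# [j * num_queries_per_kv, (j + 1) * num_queries_per_kv), so each group of
# consecutive query heads attends over the same KV head, as in the diff.
if num_queries_per_kv > 1:
    K = torch.repeat_interleave(K, num_queries_per_kv, dim=1)
    V = torch.repeat_interleave(V, num_queries_per_kv, dim=1)

scale = head_size ** -0.5  # assumed scale; the test passes scale in explicitly
scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
# Assumed completion of the attention computation for the sketch:
out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1).to(V.dtype), V)
print(out.shape)  # torch.Size([5, 8, 128])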
