     (4, 16, 16, 128, 8, 64, 3),
     (8, 64, 64, 128, 8, 16, 5),
     (16, 128, 128, 128, 8, 16, 4),
+    (2, 8, 2, 128, 16, 32, 2),
+    (4, 16, 4, 128, 8, 64, 3),
+    (1, 64, 16, 128, 8, 16, 2),
 ]

 _TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
@@ -78,6 +81,10 @@ def ref_paged_attention_multi_turn(
     query_new, k_cache, v_cache, block_tables, seq_lens, cum_seq_lens_q, scale
 ):
     block_size = k_cache.shape[2]
+    num_heads = query_new.shape[1]
+    num_kv_heads = k_cache.shape[1]
+    num_queries_per_kv = num_heads // num_kv_heads
+
     outputs = torch.zeros_like(query_new)
     num_seqs = len(cum_seq_lens_q) - 1
     for i in range(num_seqs):
@@ -95,6 +102,10 @@ def ref_paged_attention_multi_turn(

         K = torch.stack(keys_all, dim=0)
         V = torch.stack(values_all, dim=0)
+
+        if num_queries_per_kv > 1:
+            K = torch.repeat_interleave(K, num_queries_per_kv, dim=1)
+            V = torch.repeat_interleave(V, num_queries_per_kv, dim=1)
         Q = query_new[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1], :, :]

         scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale
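
The lines added above extend the reference path to grouped-query attention (GQA), where several query heads share one KV head, by expanding K and V with torch.repeat_interleave before the einsum. Below is a minimal standalone sketch of that expansion; the shapes are chosen here purely for illustration and are not taken from the test cases.

import torch

# Illustrative shapes only: 8 query heads sharing 2 KV heads (GQA).
seq_len, num_heads, num_kv_heads, head_dim = 16, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads  # 4

Q = torch.randn(seq_len, num_heads, head_dim)
K = torch.randn(seq_len, num_kv_heads, head_dim)

# Repeat each KV head for the query heads that share it, so the head
# dimensions line up for the reference einsum used in the test.
if num_queries_per_kv > 1:
    K = torch.repeat_interleave(K, num_queries_per_kv, dim=1)

scores = torch.einsum("qhd,khd->hqk", Q, K)
print(scores.shape)  # torch.Size([8, 16, 16])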