@@ -18,6 +18,7 @@ def _fwd_kernel_with_v(
     V,
     sm_scale,
     B_Start_Loc,
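+    # start offset of each batch's K/V in the packed KV tensors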
+    B_Kv_Start_Loc,
     B_Seqlen,  # B_LOC records each batch input's real start position; B_SEQ_len records the current input's real length
     Out,
     stride_q_bs,
@@ -44,7 +45,8 @@ def _fwd_kernel_with_v(

     cur_k_head = cur_head

-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch)
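+    # Q is packed with only the new (uncached) tokens, while K/V are packed with the
+    # full sequences including the prompt cache, so the two start offsets can differ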
     prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len

@@ -55,9 +57,9 @@ def _fwd_kernel_with_v(
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
+    off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
     off_q_rope = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs
+        (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
         + cur_head * stride_q_rope_h
         + offs_rope_d[None, :]
     )
@@ -84,12 +86,12 @@ def _fwd_kernel_with_v(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs,
+            k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
         k_rope = tl.load(
-            k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs,
+            k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
@@ -119,7 +121,7 @@ def _fwd_kernel_with_v(
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(
-            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+            v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
             mask=(start_n + offs_n[:, None]) < block_end_loc,
             other=0.0,
         )
@@ -129,7 +131,7 @@ def _fwd_kernel_with_v(
         l_i = l_i_new
         m_i = m_i_new
     # initialize pointers to output
-    off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
     return
@@ -144,6 +146,7 @@ def context_attention_fwd_with_v(
     v,
     o,
     b_start_loc,
+    b_kv_start_loc,
     b_seq_len,
     b_prompt_cache_len,
     max_input_len,
@@ -181,6 +184,7 @@ def context_attention_fwd_with_v(
         v,
         sm_scale,
         b_start_loc,
+        b_kv_start_loc,
         b_seq_len,
         o,
         q_nope.stride(0),
@@ -204,3 +208,78 @@ def context_attention_fwd_with_v(
         num_stages=1,
     )
     return
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F  # needed for the cosine-similarity check below
+    import flashinfer
+
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64
+    dtype = torch.bfloat16
+
+    k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+    k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda")
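+    # the rope part of K is stored once (a single head); broadcast it to H heads so the
+    # flashinfer baseline can consume one concatenated K of head_dim D_HEAD + ROPE_HEAD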
+    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1)
+    v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+
+    max_input_len = Z * N_CTX
+    softmax_scale = 0.117
+    b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
+    b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda")
+    q_lens = b_seq_len - b_prompt_cache_len
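+    # exclusive prefix sums: per-batch start offsets into the packed Q and KV tensors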
+    q_start_loc = q_lens.cumsum(0) - q_lens
+    kv_start_loc = b_seq_len.cumsum(0) - b_seq_len
+
+    q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda")
+    q = torch.cat([q_nope, q_rope], dim=-1)
+
+    o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+
+    fn1 = lambda: context_attention_fwd_with_v(
+        q_nope,
+        q_rope,
+        k_nope,
+        k_rope,
+        v,
+        o,
+        q_start_loc,
+        kv_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        max_input_len,
+        softmax_scale,
+    )
+
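+    # flashinfer baseline: CSR-style indptr arrays of length Z + 1 delimit each
+    # batch's Q and KV ranges in the packed tensors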
+    q_starts = torch.zeros((Z + 1,), dtype=torch.int32, device="cuda")
+    q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0)
+    kv_starts = torch.zeros_like(q_starts)
+    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
+    kv_layout = "NHD"
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout)
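+    # MLA-style split head dims: Q/K carry the extra rope dims (D_HEAD + ROPE_HEAD),
+    # while V and the output keep only D_HEAD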
+    wrapper.plan(
+        qo_indptr=q_starts,
+        kv_indptr=kv_starts,
+        num_qo_heads=H,
+        num_kv_heads=H,
+        head_dim_qk=D_HEAD + ROPE_HEAD,
+        head_dim_vo=D_HEAD,
+        q_data_type=dtype,
+        causal=True,
+        sm_scale=softmax_scale,
+    )
+    fn2 = lambda: wrapper.run(q, k, v, out=o1)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    print(f"cosine similarity (triton vs flashinfer): {cos_sim1.item():.6f}")
+    print(f"triton: {ms1:.3f} ms")
+    print(f"flashinfer: {ms2:.3f} ms")