@@ -185,6 +185,184 @@ def _fwd_context_paged_attention_kernel(
    return

+# Triton 2.1.0
+# TODO(yuanheng-zhao): This is a temporary dispatch to use the new layout for kcache;
+# merge `_fwd_context_paged_attention_kernel_v2` with `_fwd_context_paged_attention_kernel`
+# once the kcache layout is supported throughout the triton flow.
+@triton.jit
+def _fwd_context_paged_attention_kernel_v2(
+    Q,
+    K,
+    V,
+    O,
+    KCache,  # [num_blocks, num_kv_heads, head_dim // x, block_size, x]
+    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    BLOCK_TABLES,  # [num_seqs, max_blocks_per_sequence]
+    batch_size,
+    stride_qt,
+    stride_qh,
+    stride_qd,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_vt,
+    stride_vh,
+    stride_vd,
+    stride_ot,
+    stride_oh,
+    stride_od,
+    stride_cacheb,  # v cache stride(0) - num_blocks
+    stride_cacheh,  # v cache stride(1) - num_kv_heads
+    stride_cachebs,  # v cache stride(2) - block_size
+    stride_cached,  # v cache stride(3) - head_dim
+    stride_bts,
+    stride_btb,
+    context_lengths,
+    sm_scale,
+    KV_GROUPS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+    KCACHE_X: tl.constexpr,  # k stride on the second last dimension
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
+    cur_head_idx = tl.program_id(1)
+    block_start_m = tl.program_id(2)  # Br, max_input_len // Block_M
+    cur_kv_head_idx = cur_head_idx // KV_GROUPS
+
+    # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same
+    tl.static_assert(BLOCK_M == BLOCK_N)
+    tl.static_assert(BLOCK_N == BLOCK_SIZE)
+
+    # get the current sequence length from the provided context lengths tensor
+    cur_seq_len = tl.load(context_lengths + cur_seq_idx)
+    # NOTE when dealing with fused QKV and no-padding context attention,
+    # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum`
+    # can be considered the start index of the current sequence.
+    # FIXME might want to explore a better way to get the summation of prev seq lengths.
+    # `tl.sum(tensor[:end])` is invalid as tensor slicing is not supported in triton.
+    prev_seq_len_sum = 0
+    for i in range(0, cur_seq_idx):
+        prev_seq_len_sum += tl.load(context_lengths + i)
+
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    Q_block_ptr = tl.make_block_ptr(
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
+        strides=(stride_qt, stride_qd),
+        offsets=(block_start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, HEAD_DIM),
+        order=(1, 0),
+    )
+    K_block_ptr = tl.make_block_ptr(
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
+        strides=(stride_kd, stride_kt),
+        offsets=(0, 0),
+        block_shape=(HEAD_DIM, BLOCK_N),
+        order=(0, 1),
+    )
+    V_block_ptr = tl.make_block_ptr(
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
+        strides=(stride_vt, stride_vd),
+        offsets=(0, 0),
+        block_shape=(BLOCK_N, HEAD_DIM),
+        order=(1, 0),
+    )
+    O_block_ptr = tl.make_block_ptr(
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
+        strides=(stride_ot, stride_od),
+        offsets=(block_start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, HEAD_DIM),
+        order=(1, 0),
+    )
+
+    # block table for the current sequence
+    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
+    # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq)
+    # Consider `block_start_m` as the logical block idx in the current block table,
+    # as we have BLOCK_M the same size as the block size.
+    cur_block_table_idx = block_start_m
+    cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+
+    offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offsets_n = tl.arange(0, BLOCK_N)
+    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
+
+    if block_start_m * BLOCK_M >= cur_seq_len:
+        return
+
+    Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0))
+
+    for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N):
+        block_start_n = tl.multiple_of(block_start_n, BLOCK_N)
+
+        k = tl.load(K_block_ptr, boundary_check=(0, 1))
+        S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        S_ij += tl.dot(Q_i, k)
+        S_ij *= sm_scale
+        S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf"))
+
+        m_ij = tl.max(S_ij, 1)  # rowmax(Sij)
+        m_ij = tl.maximum(m_i, m_ij)  # m_ij
+        S_ij -= m_ij[:, None]
+        p_ij_hat = tl.exp(S_ij)
+        scale = tl.exp(m_i - m_ij)
+        l_ij = scale * l_i + tl.sum(p_ij_hat, 1)
+        acc = acc * scale[:, None]
+
+        v = tl.load(V_block_ptr, boundary_check=(1, 0))
+        p_ij_hat = p_ij_hat.to(v.type.element_ty)
+
+        acc += tl.dot(p_ij_hat, v)
+        l_i = l_ij
+        m_i = m_ij
+        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+
+    acc = acc / l_i[:, None]
+    tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0))
+
+    if cur_head_idx % KV_GROUPS == 0:
+        # Copy k to corresponding cache block
+        block_range = tl.arange(0, BLOCK_SIZE)
+        X_range = tl.arange(0, KCACHE_X)
+        # unroll the loop aggressively
+        for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
+            offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+            offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
+            k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
+            # HACK: KCache must be contiguous in order to apply the following offsets calculation
+            offsets_kcache = (
+                KCache
+                + offset_kvcache
+                + split_x * BLOCK_SIZE * KCACHE_X
+                + block_range[:, None] * KCACHE_X
+                + X_range[None, :]
+            )
+            tl.store(offsets_kcache, k, mask=block_range[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        # Copy v to corresponding cache block
+        offsets_vd = tl.arange(0, HEAD_DIM)  # offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + offsets_n
+        offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
+        offsets_vcache = (
+            VCache + offset_kvcache + block_range[None, :] * stride_cachebs + offsets_vd[:, None] * stride_cached
+        )
+        tl.store(offsets_vcache, v, mask=block_range[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+
+    return
+
+
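For reference, the m_i / l_i / acc recurrence in the loop above is the standard online-softmax (flash-attention style) update. Below is a minimal PyTorch sketch, an illustration only and not part of this patch (sizes are arbitrary), that mirrors the per-block update and checks it against plain causal softmax attention:

import torch

torch.manual_seed(0)
seq_len, head_dim, block = 64, 32, 16  # arbitrary illustration sizes
sm_scale = head_dim**-0.5
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Reference: full causal softmax attention.
scores = (q @ k.T) * sm_scale
scores = scores.masked_fill(~torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)), float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

# Blockwise online-softmax accumulation, mirroring the kernel's m_i / l_i / acc update.
out = torch.empty_like(ref)
for m0 in range(0, seq_len, block):
    q_i = q[m0 : m0 + block]
    m_i = torch.full((block,), float("-inf"))
    l_i = torch.zeros(block)
    acc = torch.zeros(block, head_dim)
    for n0 in range(0, m0 + block, block):
        s_ij = (q_i @ k[n0 : n0 + block].T) * sm_scale
        causal = (m0 + torch.arange(block))[:, None] >= (n0 + torch.arange(block))[None, :]
        s_ij = s_ij.masked_fill(~causal, float("-inf"))
        m_ij = torch.maximum(m_i, s_ij.max(dim=1).values)  # running row max
        p = torch.exp(s_ij - m_ij[:, None])
        scale = torch.exp(m_i - m_ij)  # rescale previously accumulated partial sums
        l_i = scale * l_i + p.sum(dim=1)
        acc = acc * scale[:, None] + p @ v[n0 : n0 + block]
        m_i = m_ij
    out[m0 : m0 + block] = acc / l_i[:, None]

assert torch.allclose(out, ref, atol=1e-4)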
# Triton 2.1.0
@triton.jit
def _alibi_fwd_context_paged_attention_kernel(
@@ -375,21 +553,33 @@ def context_attention_unpadded(
    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, block_size, head_dim]
    context_lengths: torch.Tensor,  # [num_seqs]
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence]
    block_size: int,
    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
    alibi_slopes: torch.Tensor = None,  # [num_heads]
    max_seq_len: int = None,
    sm_scale: int = None,
+    # NOTE(yuanheng-zhao): this flag determines whether to use the new kcache layout
+    # [num_blocks, num_kv_heads, head_dim // x, block_size, x] - must be contiguous
+    use_new_kcache_layout: bool = False,
):
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk == Lv
    assert Lk in {32, 64, 128, 256}
    assert q.shape[0] == k.shape[0] == v.shape[0]
-    assert k_cache.shape == v_cache.shape
+    k_cache_shape = k_cache.shape
+    v_cache_shape = v_cache.shape
+    if use_new_kcache_layout:
+        assert (
+            len(k_cache_shape) == 5
+            and k_cache_shape[1] == v_cache_shape[1]
+            and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3]
+        ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
+    else:
+        assert k_cache_shape == v_cache_shape, f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
    assert context_lengths.shape[0] == block_tables.shape[0]

    num_tokens, num_heads, head_dim = q.shape
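The new kcache layout packs head_dim into head_dim // x contiguous chunks of x elements each; the `Intuition: 16 // dtype_size` comment below suggests x is the number of elements that fit in 16 bytes (8 for fp16). The following host-side PyTorch sketch of the correspondence between the standard layout and the new one is an illustration only, not part of this patch (sizes are arbitrary):

import torch

num_blocks, num_kv_heads, block_size, head_dim = 4, 2, 32, 128  # arbitrary illustration sizes
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 16 bytes per chunk -> 8 elements for fp16

# Standard (v-cache style) layout: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)

# New k-cache layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x], contiguous,
# matching what the `use_new_kcache_layout` branch below expects.
k_cache_new = (
    k_cache_old.view(num_blocks, num_kv_heads, block_size, head_dim // x, x)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
)
assert k_cache_new.shape == (num_blocks, num_kv_heads, head_dim // x, block_size, x)
assert k_cache_new.shape[2] * k_cache_new.shape[4] == head_dim  # mirrors the shape check above

A cache in this form satisfies the 5-D shape assertion above and can be passed to `context_attention_unpadded(..., use_new_kcache_layout=True)`.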
@@ -413,6 +603,53 @@ def context_attention_unpadded(
    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
    grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))

+    if use_new_kcache_layout:
+        # TODO(yuanheng-zhao): The alibi kernel is very similar to the original one, so it will be
+        # refactored later to avoid code duplication, once the whole triton flow with the new
+        # kcache layout has been supported and tested.
+        assert (
+            alibi_slopes is None
+        ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
+        x = k_cache_shape[4]  # Intuition: 16 // dtype_size (e.g. 8 for fp16)
+
+        _fwd_context_paged_attention_kernel_v2[grid](
+            q,
+            k,
+            v,
+            output,
+            k_cache,
+            v_cache,
+            block_tables,
+            num_seqs,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            output.stride(0),  # stride_ot
+            head_dim,  # stride_oh
+            1,  # stride_od
+            v_cache.stride(0),
+            v_cache.stride(1),
+            v_cache.stride(2),
+            v_cache.stride(3),
+            block_tables.stride(0),
+            block_tables.stride(1),
+            context_lengths,
+            sm_scale,
+            KV_GROUPS=num_kv_group,
+            BLOCK_SIZE=block_size,
+            HEAD_DIM=Lk,
+            KCACHE_X=x,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N,
+        )
+        return output
+
    if alibi_slopes is not None:
        _alibi_fwd_context_paged_attention_kernel[grid](
            q,