@@ -123,29 +123,9 @@ class FlashInferMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.

-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    qo_indptr_cpu: torch.Tensor
-    # An example for paged_kv_indices, paged_kv_indptr:
-    # request 1, page indices [0, 5, 8]
-    # request 2, page indices [1, 6, 7]
-    # request 3, page indices [3, 4]
-    # paged_kv_indices is a concatenation of page indices of all requests:
-    # [0, 5, 8, 1, 6, 7, 3, 4]
-    # paged_kv_indptr is used to index into paged_kv_indices:
-    # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
-    paged_kv_indptr_cpu: torch.Tensor
-    # The page indices of the paged kv cache (on device for plan)
-    paged_kv_indices: torch.Tensor
-    # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size] (CPU for plan)
-    paged_kv_last_page_len_cpu: torch.Tensor
     # The data type of the query
     q_data_type: torch.dtype

-    seq_lens_cpu: torch.Tensor
     slot_mapping: torch.Tensor

     # For flashinfer trtllm batch decode
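For readers following the refactor, the removed comments above describe the CSR-style layout that FlashInfer's planner consumes. The snippet below is an illustrative sketch only (it is not part of the diff); the values mirror the examples from those comments and show how `paged_kv_indices`, `paged_kv_indptr`, and `qo_indptr` relate to per-request page lists and query lengths.

```python
# Illustrative sketch only -- not part of the diff. Values mirror the
# examples in the removed comments.
import torch

page_ids_per_request = [[0, 5, 8], [1, 6, 7], [3, 4]]  # requests 1..3
query_lens = [4, 6]

# paged_kv_indices: concatenation of every request's page ids.
paged_kv_indices = torch.tensor(
    [p for pages in page_ids_per_request for p in pages])
print(paged_kv_indices)  # tensor([0, 5, 8, 1, 6, 7, 3, 4])

# paged_kv_indptr: exclusive prefix sum of the per-request page counts,
# so request i owns paged_kv_indices[indptr[i]:indptr[i + 1]].
page_counts = torch.tensor([len(p) for p in page_ids_per_request])
paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int64), page_counts.cumsum(0)])
print(paged_kv_indptr)  # tensor([0, 3, 6, 8])

# qo_indptr: the same construction over per-request query lengths.
qo_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int64), torch.tensor(query_lens).cumsum(0)])
print(qo_indptr)  # tensor([0, 4, 10])
```

Request `i` owns `paged_kv_indices[indptr[i]:indptr[i + 1]]`, which is why the indptr must stay non-decreasing; that invariant comes up again in the cudagraph padding further down.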
@@ -164,10 +144,6 @@ class FlashInferMetadata:

     # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
-    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
-    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
-    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
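The `shared_*` tensors dropped here feed the two-level cascade plan; after this change they stay as locals in `build()` instead of metadata fields. As a rough sketch of what they typically contain — this is an assumption for illustration, not something stated in the diff — the shared level describes a single pseudo-request covering the KV prefix common to the whole batch:

```python
# Rough sketch under an assumption: the shared level is one pseudo-request
# covering the KV prefix common to the whole batch. All values are invented.
import torch

page_size = 16
num_query_tokens = 10     # total scheduled tokens in the batch
common_prefix_len = 40    # KV tokens shared by every request

num_shared_pages = (common_prefix_len + page_size - 1) // page_size  # 3
shared_qo_indptr_cpu = torch.tensor([0, num_query_tokens])
shared_kv_page_indptr_cpu = torch.tensor([0, num_shared_pages])
shared_kv_page_indices_cpu = torch.arange(num_shared_pages)           # [0, 1, 2]
shared_kv_last_page_len_cpu = torch.tensor(
    [common_prefix_len - (num_shared_pages - 1) * page_size])         # [8]
```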
@@ -327,134 +303,6 @@ def _get_cascade_wrapper(self):
                 2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper

-    def _plan(self, attn_metadata: FlashInferMetadata):
-        if attn_metadata.use_cascade:
-            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
-            attn_metadata.cascade_wrapper.plan(
-                [
-                    attn_metadata.shared_qo_indptr_cpu,
-                    attn_metadata.qo_indptr_cpu
-                ],
-                [
-                    attn_metadata.shared_kv_page_indptr_cpu,
-                    attn_metadata.paged_kv_indptr_cpu
-                ],
-                [
-                    attn_metadata.shared_kv_page_indices_cpu,
-                    attn_metadata.paged_kv_indices
-                ],
-                [
-                    attn_metadata.shared_kv_last_page_len_cpu,
-                    attn_metadata.paged_kv_last_page_len_cpu
-                ],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                self.page_size,
-                causal=True,
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.kv_cache_dtype,
-            )
-        else:
-            # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
-            num_prefills = attn_metadata.num_prefills
-            num_decodes = attn_metadata.num_decodes
-            if num_prefills > 0:
-                # Decodes are first so prefills start after the last decode
-                prefill_start = num_decodes
-                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-                assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
-                    0] == num_prefills + 1
-                assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
-                    0] == num_prefills + 1
-                assert attn_metadata.paged_kv_last_page_len_cpu[
-                    prefill_start:].shape[0] == num_prefills
-                # Since prefill_wrapper.run() will be called with
-                # query[num_decode_tokens:] we need to adjust the qo_indptr
-                # to be relative to the start of the prefill queries.
-                qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
-                    prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
-                paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
-                    prefill_start:]
-                if not attn_metadata.prefill_use_trtllm:
-                    attn_metadata.prefill_wrapper.plan(
-                        qo_indptr_cpu,
-                        paged_kv_indptr_cpu,
-                        attn_metadata.paged_kv_indices,
-                        attn_metadata.
-                        paged_kv_last_page_len_cpu[prefill_start:],
-                        self.num_qo_heads,
-                        self.num_kv_heads,
-                        self.head_dim,
-                        self.page_size,
-                        causal=True,
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
-                        q_data_type=self.q_data_type,
-                        kv_data_type=self.kv_cache_dtype,
-                    )
-                else:
-                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
-                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
-                        self.device)
-
-            if num_decodes > 0:
-                pure_decode = num_prefills == 0
-                # possible required padding for cudagraph replay
-                use_cudagraph = (self.enable_cuda_graph and pure_decode and
-                                 num_decodes <= self._decode_cudagraph_max_bs)
-                if use_cudagraph:
-                    num_input_tokens = (
-                        self.vllm_config.pad_for_cudagraph(num_decodes))
-                    # Carefully fulfill the padding region with reasonable value
-                    # on cpu.
-                    # Make sure paged_kv_indptr_cpu is not decreasing
-                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
-                                             num_input_tokens].fill_(
-                                                 attn_metadata.
-                                                 paged_kv_indptr_cpu[-1])
-                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
-                    # This is because flashinfer treats 0 as a full page
-                    # instead of empty.
-                    self.paged_kv_last_page_len_cpu[
-                        num_decodes:num_input_tokens].fill_(1)
-
-                else:
-                    num_input_tokens = num_decodes
-
-                attn_metadata.decode_wrapper = self._get_decode_wrapper(
-                    num_input_tokens, use_cudagraph)
-                if not attn_metadata.decode_use_trtllm:
-                    # Use the persistent buffer with padding length,
-                    # instead of the same address but chunked version
-                    # in attn_metadata when using cudagraph.
-                    fast_plan_decode(
-                        attn_metadata.decode_wrapper,
-                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
-                        attn_metadata.paged_kv_indices,
-                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
-                        attn_metadata.seq_lens_cpu[:num_input_tokens],
-                        self.num_qo_heads,
-                        self.num_kv_heads,
-                        self.head_dim,
-                        self.page_size,
-                        # Disable flashinfer's pos encoding and use vllm's rope.
-                        pos_encoding_mode="NONE",
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
-                        q_data_type=self.q_data_type,
-                        kv_data_type=self.kv_cache_dtype,
-                    )
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
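One detail worth calling out from the deleted `_plan` (and kept verbatim in the inlined version further down): before planning the prefill wrapper, `qo_indptr` is rebased so it is relative to the first prefill token, because `prefill_wrapper.run()` only sees `query[num_decode_tokens:]`. A minimal standalone illustration with invented values:

```python
# Standalone illustration with invented values.
import torch

# Batch after reorder_batch(): 2 decodes (1 token each), then prefills of
# 5 and 3 tokens. qo_indptr_cpu over the whole batch:
qo_indptr_cpu = torch.tensor([0, 1, 2, 7, 10])
num_decodes = 2
prefill_start = num_decodes

# prefill_wrapper.run() only sees query[num_decode_tokens:], so the indptr
# passed to plan() must be rebased to start at 0.
prefill_qo_indptr = (qo_indptr_cpu[prefill_start:] -
                     qo_indptr_cpu[prefill_start])
print(prefill_qo_indptr)  # tensor([0, 5, 8])
```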
@@ -548,13 +396,7 @@ def build(self,

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
-            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len_cpu=self.
-            paged_kv_last_page_len_cpu[:num_reqs],
             q_data_type=self.q_data_type,
-            seq_lens_cpu=seq_lens_cpu,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -567,14 +409,123 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
-            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
-            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
-            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
         )

-        self._plan(attn_metadata)
+        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
+        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
+        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

+        if attn_metadata.use_cascade:
+            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
+            attn_metadata.cascade_wrapper.plan(
+                [shared_qo_indptr_cpu, qo_indptr_cpu],
+                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
+                [shared_kv_page_indices_cpu, paged_kv_indices],
+                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
+                causal=True,
+                sm_scale=self.global_hyperparameters.sm_scale,
+                window_left=self.global_hyperparameters.window_left,
+                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.kv_cache_dtype,
+            )
+        else:
+            # Regular attention (common case).
+            # Decodes are at the front and prefills are at the back,
+            # according to reorder_batch()
+            num_prefills = attn_metadata.num_prefills
+            num_decodes = attn_metadata.num_decodes
+            if num_prefills > 0:
+                # Decodes are first so prefills start after the last decode
+                prefill_start = num_decodes
+                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
+                assert qo_indptr_cpu[prefill_start:].shape[
+                    0] == num_prefills + 1
+                assert paged_kv_indptr_cpu[prefill_start:].shape[
+                    0] == num_prefills + 1
+                assert paged_kv_last_page_len_cpu[prefill_start:].shape[
+                    0] == num_prefills
+                # Since prefill_wrapper.run() will be called with
+                # query[num_decode_tokens:] we need to adjust the qo_indptr
+                # to be relative to the start of the prefill queries.
+                qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
+                    prefill_start]
+                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+                if not attn_metadata.prefill_use_trtllm:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        paged_kv_indices,
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        causal=True,
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                else:
+                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
+                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
+                        self.device)
+
+            if num_decodes > 0:
+                pure_decode = num_prefills == 0
+                # possible required padding for cudagraph replay
+                use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                                 num_decodes <= self._decode_cudagraph_max_bs)
+                if use_cudagraph:
+                    num_input_tokens = (
+                        self.vllm_config.pad_for_cudagraph(num_decodes))
+                    # Carefully fulfill the padding region with reasonable value
+                    # on cpu.
+                    # Make sure paged_kv_indptr_cpu is not decreasing
+                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
+                                             num_input_tokens].fill_(
+                                                 paged_kv_indptr_cpu[-1])
+                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
+                    # This is because flashinfer treats 0 as a full page
+                    # instead of empty.
+                    self.paged_kv_last_page_len_cpu[
+                        num_decodes:num_input_tokens].fill_(1)
+
+                else:
+                    num_input_tokens = num_decodes
+
+                attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                    num_input_tokens, use_cudagraph)
+                if not attn_metadata.decode_use_trtllm:
+                    # Use the persistent buffer with padding length,
+                    # instead of the same address but chunked version
+                    # in attn_metadata when using cudagraph.
+                    fast_plan_decode(
+                        attn_metadata.decode_wrapper,
+                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
+                        paged_kv_indices,
+                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
+                        seq_lens_cpu[:num_input_tokens],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        # Disable flashinfer's pos encoding and use vllm's rope.
+                        pos_encoding_mode="NONE",
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
         return attn_metadata

     def build_for_cudagraph_capture(
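The cudagraph padding added above is easier to read with concrete numbers. The sketch below is illustrative only (tensor contents are invented): the padded tail of `paged_kv_indptr_cpu` repeats the last real offset, so padded slots own zero pages and the indptr stays non-decreasing, while `paged_kv_last_page_len_cpu` is padded with 1 because FlashInfer treats 0 as a full page rather than an empty one.

```python
# Illustrative sketch only; tensor contents are invented.
import torch

num_decodes = 3        # real decode requests in the batch
num_input_tokens = 8   # batch size after vllm_config.pad_for_cudagraph()

# Persistent CPU buffers, as in the builder (sizes chosen for the example).
paged_kv_indptr_cpu = torch.zeros(num_input_tokens + 1, dtype=torch.int32)
paged_kv_last_page_len_cpu = torch.zeros(num_input_tokens, dtype=torch.int32)

# Real entries for the 3 decode requests.
paged_kv_indptr_cpu[:num_decodes + 1] = torch.tensor([0, 4, 6, 9])
paged_kv_last_page_len_cpu[:num_decodes] = torch.tensor([5, 16, 2])

# Padding: repeat the last real offset so padded slots own zero pages and
# the indptr stays non-decreasing ...
paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_(
    paged_kv_indptr_cpu[num_decodes])
# ... and use 1 (not 0) for last_page_len, since flashinfer reads 0 as a
# full page rather than an empty one.
paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(1)

print(paged_kv_indptr_cpu)         # [0, 4, 6, 9, 9, 9, 9, 9, 9]
print(paged_kv_last_page_len_cpu)  # [5, 16, 2, 1, 1, 1, 1, 1]
```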