Commit ec2795f

WoosukKwon authored and epwalsh committed
[Misc] Simplify FlashInfer attention metadata (vllm-project#23585)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent e1a317a commit ec2795f

File tree: 1 file changed (+114, -163 lines)


vllm/v1/attention/backends/flashinfer.py

Lines changed: 114 additions & 163 deletions
@@ -123,29 +123,9 @@ class FlashInferMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
 
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    qo_indptr_cpu: torch.Tensor
-    # An example for paged_kv_indices, paged_kv_indptr:
-    # request 1, page indices [0, 5, 8]
-    # request 2, page indices [1, 6, 7]
-    # request 3, page indices [3, 4]
-    # paged_kv_indices is a concatenation of page indices of all requests:
-    # [0, 5, 8, 1, 6, 7, 3, 4]
-    # paged_kv_indptr is used to index into paged_kv_indices:
-    # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
-    paged_kv_indptr_cpu: torch.Tensor
-    # The page indices of the paged kv cache (on device for plan)
-    paged_kv_indices: torch.Tensor
-    # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size] (CPU for plan)
-    paged_kv_last_page_len_cpu: torch.Tensor
     # The data type of the query
     q_data_type: torch.dtype
 
-    seq_lens_cpu: torch.Tensor
     slot_mapping: torch.Tensor
 
     # For flashinfer trtllm batch decode
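
The comments deleted above were the only prose description of these index tensors, so for reference: a minimal, standalone PyTorch sketch (illustrative only, not vLLM code; names mirror the removed fields) reproducing the example from the removed comments.

import torch

# Per-request KV page indices from the removed comment's example:
# request 1 -> [0, 5, 8], request 2 -> [1, 6, 7], request 3 -> [3, 4]
page_lists = [[0, 5, 8], [1, 6, 7], [3, 4]]

# paged_kv_indices concatenates every request's pages: [0, 5, 8, 1, 6, 7, 3, 4]
paged_kv_indices = torch.tensor(sum(page_lists, []))

# paged_kv_indptr is the exclusive prefix sum of the page counts: [0, 3, 6, 8],
# so request i owns paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]]
counts = torch.tensor([len(p) for p in page_lists])
paged_kv_indptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])

# qo_indptr is the same construction over query lengths: [4, 6] -> [0, 4, 10]
query_lens = torch.tensor([4, 6])
qo_indptr = torch.cat([torch.zeros(1, dtype=torch.long), query_lens.cumsum(0)])
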
@@ -164,10 +144,6 @@ class FlashInferMetadata:
 
     # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
-    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
-    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
-    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
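
The four shared_* fields dropped here feed the cascade path, which survives as a plan call inside build() (see the last hunk below): each index argument is passed as a two-level [shared, per_request] list, level 0 describing the prefix common to the whole batch and level 1 the per-request remainders. A rough sketch of that layout (values and shapes illustrative only, inferred from the plan call):

import torch

# Level 0: the shared prefix, a single entry covering the whole batch.
shared_kv_page_indices_cpu = torch.tensor([0, 1])   # prefix stored in pages 0-1
shared_kv_page_indptr_cpu = torch.tensor([0, 2])
shared_kv_last_page_len_cpu = torch.tensor([16])

# Level 1: each request's unique suffix pages.
paged_kv_indices = torch.tensor([5, 8, 6])
paged_kv_indptr_cpu = torch.tensor([0, 2, 3])
paged_kv_last_page_len_cpu = torch.tensor([3, 9])

# The cascade wrapper consumes them as per-level lists, e.g.:
kv_page_indptr_levels = [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu]
kv_page_indices_levels = [shared_kv_page_indices_cpu, paged_kv_indices]
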
@@ -327,134 +303,6 @@ def _get_cascade_wrapper(self):
                 2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
 
-    def _plan(self, attn_metadata: FlashInferMetadata):
-        if attn_metadata.use_cascade:
-            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
-            attn_metadata.cascade_wrapper.plan(
-                [
-                    attn_metadata.shared_qo_indptr_cpu,
-                    attn_metadata.qo_indptr_cpu
-                ],
-                [
-                    attn_metadata.shared_kv_page_indptr_cpu,
-                    attn_metadata.paged_kv_indptr_cpu
-                ],
-                [
-                    attn_metadata.shared_kv_page_indices_cpu,
-                    attn_metadata.paged_kv_indices
-                ],
-                [
-                    attn_metadata.shared_kv_last_page_len_cpu,
-                    attn_metadata.paged_kv_last_page_len_cpu
-                ],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                self.page_size,
-                causal=True,
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.kv_cache_dtype,
-            )
-        else:
-            # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
-            num_prefills = attn_metadata.num_prefills
-            num_decodes = attn_metadata.num_decodes
-            if num_prefills > 0:
-                # Decodes are first so prefills start after the last decode
-                prefill_start = num_decodes
-                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-                assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
-                    0] == num_prefills + 1
-                assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
-                    0] == num_prefills + 1
-                assert attn_metadata.paged_kv_last_page_len_cpu[
-                    prefill_start:].shape[0] == num_prefills
-                # Since prefill_wrapper.run() will be called with
-                # query[num_decode_tokens:] we need to adjust the qo_indptr
-                # to be relative to the start of the prefill queries.
-                qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
-                    prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
-                paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[
-                    prefill_start:]
-                if not attn_metadata.prefill_use_trtllm:
-                    attn_metadata.prefill_wrapper.plan(
-                        qo_indptr_cpu,
-                        paged_kv_indptr_cpu,
-                        attn_metadata.paged_kv_indices,
-                        attn_metadata.
-                        paged_kv_last_page_len_cpu[prefill_start:],
-                        self.num_qo_heads,
-                        self.num_kv_heads,
-                        self.head_dim,
-                        self.page_size,
-                        causal=True,
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
-                        q_data_type=self.q_data_type,
-                        kv_data_type=self.kv_cache_dtype,
-                    )
-                else:
-                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
-                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
-                        self.device)
-
-            if num_decodes > 0:
-                pure_decode = num_prefills == 0
-                # possible required padding for cudagraph replay
-                use_cudagraph = (self.enable_cuda_graph and pure_decode and
-                                 num_decodes <= self._decode_cudagraph_max_bs)
-                if use_cudagraph:
-                    num_input_tokens = (
-                        self.vllm_config.pad_for_cudagraph(num_decodes))
-                    # Carefully fulfill the padding region with reasonable value
-                    # on cpu.
-                    # Make sure paged_kv_indptr_cpu is not decreasing
-                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
-                                             num_input_tokens].fill_(
-                                                 attn_metadata.
-                                                 paged_kv_indptr_cpu[-1])
-                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
-                    # This is because flashinfer treats 0 as a full page
-                    # instead of empty.
-                    self.paged_kv_last_page_len_cpu[
-                        num_decodes:num_input_tokens].fill_(1)
-
-                else:
-                    num_input_tokens = num_decodes
-
-                attn_metadata.decode_wrapper = self._get_decode_wrapper(
-                    num_input_tokens, use_cudagraph)
-                if not attn_metadata.decode_use_trtllm:
-                    # Use the persistent buffer with padding length,
-                    # instead of the same address but chunked version
-                    # in atten_metadata when using cudagraph.
-                    fast_plan_decode(
-                        attn_metadata.decode_wrapper,
-                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
-                        attn_metadata.paged_kv_indices,
-                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
-                        attn_metadata.seq_lens_cpu[:num_input_tokens],
-                        self.num_qo_heads,
-                        self.num_kv_heads,
-                        self.head_dim,
-                        self.page_size,
-                        # Disable flashinfer's pos encoding and use vllm's rope.
-                        pos_encoding_mode="NONE",
-                        sm_scale=self.global_hyperparameters.sm_scale,
-                        window_left=self.global_hyperparameters.window_left,
-                        logits_soft_cap=self.global_hyperparameters.
-                        logits_soft_cap,
-                        q_data_type=self.q_data_type,
-                        kv_data_type=self.kv_cache_dtype,
-                    )
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
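
Nothing here is deleted logic: the body of _plan() reappears nearly verbatim at the end of build() (last hunk below), reading local variables instead of round-tripping through metadata fields. The one step worth unpacking is the qo_indptr rebase for prefills; a standalone numeric sketch (batch composition hypothetical):

import torch

# After reorder_batch(): 2 decodes (1 token each), then 2 prefills
# (4 and 6 tokens), so the batch-wide qo_indptr is:
qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 12])
prefill_start = 2  # num_decodes

# prefill_wrapper.run() only sees query[num_decode_tokens:], so the indptr
# must be rebased to start at 0 for the first prefill request:
prefill_qo_indptr = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
# -> tensor([0, 4, 10]): exactly the prefill query lengths [4, 6]
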
@@ -548,13 +396,7 @@ def build(self,
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
-            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len_cpu=self.
-            paged_kv_last_page_len_cpu[:num_reqs],
             q_data_type=self.q_data_type,
-            seq_lens_cpu=seq_lens_cpu,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -567,14 +409,123 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
-            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
-            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
-            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
         )
 
-        self._plan(attn_metadata)
+        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
+        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
+        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
 
+        if attn_metadata.use_cascade:
+            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
+            attn_metadata.cascade_wrapper.plan(
+                [shared_qo_indptr_cpu, qo_indptr_cpu],
+                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
+                [shared_kv_page_indices_cpu, paged_kv_indices],
+                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
+                causal=True,
+                sm_scale=self.global_hyperparameters.sm_scale,
+                window_left=self.global_hyperparameters.window_left,
+                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.kv_cache_dtype,
+            )
+        else:
+            # Regular attention (common case).
+            # Decodes are at the front and prefills are at the back,
+            # according to reorder_batch()
+            num_prefills = attn_metadata.num_prefills
+            num_decodes = attn_metadata.num_decodes
+            if num_prefills > 0:
+                # Decodes are first so prefills start after the last decode
+                prefill_start = num_decodes
+                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
+                assert qo_indptr_cpu[prefill_start:].shape[
+                    0] == num_prefills + 1
+                assert paged_kv_indptr_cpu[prefill_start:].shape[
+                    0] == num_prefills + 1
+                assert paged_kv_last_page_len_cpu[prefill_start:].shape[
+                    0] == num_prefills
+                # Since prefill_wrapper.run() will be called with
+                # query[num_decode_tokens:] we need to adjust the qo_indptr
+                # to be relative to the start of the prefill queries.
+                qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
+                    prefill_start]
+                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+                if not attn_metadata.prefill_use_trtllm:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        paged_kv_indices,
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        causal=True,
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                else:
+                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
+                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
+                        self.device)
+
+            if num_decodes > 0:
+                pure_decode = num_prefills == 0
+                # possible required padding for cudagraph replay
+                use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                                 num_decodes <= self._decode_cudagraph_max_bs)
+                if use_cudagraph:
+                    num_input_tokens = (
+                        self.vllm_config.pad_for_cudagraph(num_decodes))
+                    # Carefully fulfill the padding region with reasonable value
+                    # on cpu.
+                    # Make sure paged_kv_indptr_cpu is not decreasing
+                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
+                                             num_input_tokens].fill_(
+                                                 paged_kv_indptr_cpu[-1])
+                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
+                    # This is because flashinfer treats 0 as a full page
+                    # instead of empty.
+                    self.paged_kv_last_page_len_cpu[
+                        num_decodes:num_input_tokens].fill_(1)
+
+                else:
+                    num_input_tokens = num_decodes
+
+                attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                    num_input_tokens, use_cudagraph)
+                if not attn_metadata.decode_use_trtllm:
+                    # Use the persistent buffer with padding length,
+                    # instead of the same address but chunked version
+                    # in atten_metadata when using cudagraph.
+                    fast_plan_decode(
+                        attn_metadata.decode_wrapper,
+                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
+                        paged_kv_indices,
+                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
+                        seq_lens_cpu[:num_input_tokens],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        # Disable flashinfer's pos encoding and use vllm's rope.
+                        pos_encoding_mode="NONE",
+                        sm_scale=self.global_hyperparameters.sm_scale,
+                        window_left=self.global_hyperparameters.window_left,
+                        logits_soft_cap=self.global_hyperparameters.
+                        logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
         return attn_metadata
 
     def build_for_cudagraph_capture(
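
One subtlety carried over unchanged: when a pure-decode batch is padded up to a captured cudagraph size, the padded tail must hold values FlashInfer will accept. A standalone sketch of why the chosen fill values are safe (sizes hypothetical):

import torch

# Pad 3 real decode requests up to a captured batch size of 6.
num_decodes, num_input_tokens = 3, 6
paged_kv_indptr_cpu = torch.tensor([0, 2, 5, 7, 0, 0, 0])
paged_kv_last_page_len_cpu = torch.tensor([16, 3, 9, 0, 0, 0])

# Repeat the last real indptr value so the sequence stays non-decreasing:
# every padded request then owns zero KV pages.
paged_kv_indptr_cpu[1 + num_decodes:1 + num_input_tokens].fill_(
    paged_kv_indptr_cpu[num_decodes])
# -> tensor([0, 2, 5, 7, 7, 7, 7])

# A last_page_len of 0 would mean a *full* page to FlashInfer, so pad with 1.
paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(1)
# -> tensor([16, 3, 9, 1, 1, 1])
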
