Commit bf5cec9

Cherry pick vllm-project#19158
1 parent 82acc9c commit bf5cec9

File tree

4 files changed: +330 −178 lines


vllm/attention/ops/triton_unified_attention.py

Lines changed: 0 additions & 2 deletions
@@ -711,7 +711,6 @@ def unified_attention(
         scale=softmax_scale,
         k_scale=k_descale,
         v_scale=v_descale,
-        out_scale=output_scale,
         softcap=softcap,
         num_query_heads=num_query_heads,
         num_queries_per_kv=num_queries_per_kv,
@@ -737,7 +736,6 @@ def unified_attention(
         num_seqs=num_seqs,
         BLOCK_M=BLOCK_M,
         NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
-        USE_FP8=output_scale is not None,
     )

     reduce_segments[(q.shape[0], num_query_heads)](

vllm/v1/attention/backends/flash_attn.py

Lines changed: 3 additions & 169 deletions
@@ -19,9 +19,9 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              get_kv_cache_layout)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -126,172 +126,6 @@ class LocalAttentionMetadata:
     local_attn_metadata: Optional[LocalAttentionMetadata] = None


-#
-# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
-# local attention blocks, where each block is passed to the attention kernel
-# as an independent local ("virtual") batch item.
-#
-# For example, if are performing a chunked prefill a batch of 3 sequences:
-# q_seqlens = [4, 10, 5]
-# kv_seqlens = [6, 17, 9]
-# Then normally for regular attention we would compute with an attention mask
-#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
-#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
-#        k_toks >   0 1 2 3 4 5
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#               2 | 1 1 1 1 1
-#               3 | 1 1 1 1 1 1
-#
-# for local attention (with attn_chunk_size = 4) we would compute with an
-# attention mask like:
-#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
-#        k_toks >   0 1 2 3 4 5
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#               2 |         1
-#               3 |         1 1
-#
-# We can simulate this mask using standard flash-attention by breaking the
-# sequences into local ("virtual") batches, where each local batch item is a
-# local attention block, so in this case batch idx 0 would be broken up into:
-#
-#  local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
-#        k_toks >   0 1 2 3
-#        q_toks v  _____________
-#               0 | 1 1 1
-#               1 | 1 1 1 1
-#  local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
-#        k_toks >   4 5
-#        q_toks v  _____________
-#               2 | 1
-#               3 | 1 1
-#
-# e.g. if we have:
-#   attn_chunk_size = 4
-#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
-# Then this function would return:
-#                           __b0__  ______b1______  __b2__  < orig batch indices
-#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
-#   cu_seqlens_q_local = [0,  4,  6, 10, 14, 18, 19, 23, 24]
-#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
-#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
-def make_local_attention_virtual_batches(
-    attn_chunk_size: int,
-    query_start_loc_np: np.ndarray,
-    seq_lens_np: np.ndarray,
-    block_table: torch.Tensor,
-    block_size: int = 0,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
-    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
-    actual_batch_size = seq_lens_np.shape[0]
-
-    # Handle if we are starting in the middle of a local attention block,
-    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
-    #  the number of tokens that are not in the first local attention block and
-    #  then we can simply use a cdiv for the rest.
-    # For example if we have:
-    #   attn_chunk_size = 4
-    #   q_seqlens = [4, 10, 5]
-    #   k_seqlens = [6, 17, 9]
-    # Then we would get:
-    #   new_tokens_in_first_block = [2, 1, 4]
-    #   local_blocks = [2, 4, 2]
-    q_tokens_in_first_block = np.minimum(
-        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
-        q_seqlens).astype(np.int32)
-    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
-    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
-                            attn_chunk_size)
-
-    # Once we know the number of local blocks we can compute the request spans
-    #  for each batch idx, we can figure out the number of "virtual" requests we
-    #  have to make,
-    # For the above example we would get:
-    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
-    #
-    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
-    #  (TODO: max a utility to share this code with _prepare_inputs)
-    # arange step 1. [2, 4, 2] -> [2, 6, 8]
-    cu_num_blocks = np.cumsum(local_blocks)
-    virtual_batches = cu_num_blocks[-1]
-    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
-    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
-    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
-    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
-    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
-    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
-    # Then we can compute the seqlens_q_local, handling the fact that the
-    #  first and last blocks could be partial
-    seqlens_q_local = \
-        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
-    # set the first block since this may be a partial block
-    seqlens_q_local[arange == 0] = q_tokens_in_first_block
-    # set the remaining blocks
-    seqlens_q_local[arange > 0] = np.minimum(
-        seqlens_q_local - attn_chunk_size * (arange - 1),
-        attn_chunk_size)[arange > 0]
-
-    # convert from q_seqlens to cu_seqlens_q
-    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
-        .astype(np.int32)
-
-    # compute the seqlens_k_local,
-    #  basically a full local attention block for all but the last block in each
-    #  batch
-    # For our example this will be:
-    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
-    seqlens_k_local = np.full(cu_num_blocks[-1],
-                              attn_chunk_size,
-                              dtype=np.int32)
-    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
-
-    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
-        (rarange * attn_chunk_size + \
-            np.repeat(tokens_in_last_block, local_blocks))
-    # For the example the local attention blocks start at:
-    #                           _b0_  _____b1_____  _b2_
-    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
-    block_starts = k_seqstarts_absolute // block_size
-    assert attn_chunk_size % block_size == 0, \
-        f"attn_chunk_size {attn_chunk_size} is not " \
-        f"divisible by block_size {block_size}"
-    pages_per_local_batch = attn_chunk_size // block_size
-
-    # Create a block_table for the local attention blocks
-    # For out example if we have a block-table like (assuming block_size=2):
-    #   block_table = [
-    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
-    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
-    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
-    #   ]
-    # Then for the local batches we would want a block-table like
-    #   block_table_local = [
-    #     [  0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
-    #     [  2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
-    #     [ 12, 13 ],  < local-batch 2, (batch 1, starting from k[4])
-    #     [ 14, 15 ],  < local-batch 3, (batch 1, starting from k[8])
-    #     [ 16, 17 ],  < local-batch 4, (batch 1, starting from k[12])
-    #     [ 18, 19 ],  < local-batch 5, (batch 1, starting from k[16])
-    #     [ 22, 23 ],  < local-batch 6, (batch 2, starting from k[4])
-    #     [ 24, 25 ],  < local-batch 7, (batch 2, starting from k[8])
-    #   ]
-    block_indices = np.broadcast_to(
-        np.arange(pages_per_local_batch, dtype=np.int32),
-        (virtual_batches, pages_per_local_batch)) \
-            + np.expand_dims(block_starts, axis=1)
-    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
-    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
-                              local_blocks * pages_per_local_batch)
-    block_table_local = block_table[batch_indices, block_indices]\
-        .view(virtual_batches, -1)
-
-    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
-        block_table_local
-
-
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
     """Get the set of all sliding window configs used in the model."""

vllm/v1/attention/backends/triton_attn.py

Lines changed: 159 additions & 7 deletions
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import TYPE_CHECKING, Any, Optional
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Optional

 import torch

@@ -15,8 +16,10 @@
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -26,12 +29,161 @@
 logger = init_logger(__name__)


-class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+@dataclass
+class TritonAttentionMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    #   |---------- N-1 iteration --------|
+    #   |---------------- N iteration ---------------------|
+    #   |- tokenA -|......................|-- newTokens ---|
+    #   |---------- context_len ----------|
+    #   |-------------------- seq_len ---------------------|
+    #                                     |-- query_len ---|
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+    # For cascade attention.
+    use_cascade: bool
+    common_prefix_len: int
+    cu_prefix_query_lens: Optional[torch.Tensor]
+    prefix_kv_lens: Optional[torch.Tensor]
+    suffix_kv_lens: Optional[torch.Tensor]
+
+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
+    # for local attention
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor
+        local_seqused_k: torch.Tensor
+        local_block_table: torch.Tensor
+        local_max_query_len: int
+        local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+class TritonAttentionMetadataBuilder(
+        AttentionMetadataBuilder[TritonAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True

     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table)
-        self.aot_schedule = False
+        self.runner = runner
+        self.block_size = kv_cache_spec.block_size
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        attn_metadata = self.build(0, common_attn_metadata)
+        # When doing full graph capture, setting seq_lens to
+        # max_model_len will cause graph capture to be extremely
+        # slow, so here we set it to 1.
+        attn_metadata.seq_lens.fill_(1)
+        return attn_metadata
+
+    def build(
+        self, common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+
+        # for local attention
+        local_attn_metadata = None
+        if self.runner.attention_chunk_size is not None:
+            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
+                virt_block_table_tensor = make_local_attention_virtual_batches(
+                    self.runner.attention_chunk_size,
+                    self.runner.query_start_loc_np[:num_reqs + 1],
+                    self.runner.seq_lens_np[:num_reqs],
+                    block_table_tensor,
+                    self.block_size,
+                )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+
+            local_attn_metadata = TritonAttentionMetadata \
+                .LocalAttentionMetadata(
+                    local_query_start_loc=local_query_start_loc,
+                    local_seqused_k=local_seqused_k,
+                    local_block_table=virt_block_table_tensor,
+                    local_max_query_len=local_max_query_len,
+                    local_max_seq_len=local_max_seq_len,
+                    local_scheduler_metadata=None,
+                )
+
+        use_cascade = common_prefix_len > 0
+
+        if use_cascade:
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+            prefix_kv_lens = torch.tensor([common_prefix_len],
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                              common_prefix_len)
+            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
+                self.runner.device)
+        else:
+            cu_prefix_query_lens = None
+            prefix_kv_lens = None
+            suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        attn_metadata = TritonAttentionMetadata(
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=query_start_loc,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table=block_table_tensor,
+            slot_mapping=slot_mapping,
+            use_cascade=use_cascade,
+            common_prefix_len=common_prefix_len,
+            cu_prefix_query_lens=cu_prefix_query_lens,
+            prefix_kv_lens=prefix_kv_lens,
+            suffix_kv_lens=suffix_kv_lens,
+            local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
+        )
+        return attn_metadata
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        # Full CUDA Graph always supported
+        return True


 class TritonAttentionBackend(AttentionBackend):
@@ -52,7 +204,7 @@ def get_impl_cls() -> type["TritonAttentionImpl"]:

     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlashAttentionMetadata
+        return TritonAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(
