|
19 | 19 | from vllm.logger import init_logger |
20 | 20 | from vllm.platforms import current_platform |
21 | 21 | from vllm.utils import cdiv |
22 | | -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, |
23 | | -                                              CommonAttentionMetadata, |
24 | | -                                              get_kv_cache_layout) |
| 22 | +from vllm.v1.attention.backends.utils import ( |
| 23 | +    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, |
| 24 | +    make_local_attention_virtual_batches) |
25 | 25 | from vllm.v1.kv_cache_interface import AttentionSpec |
26 | 26 | from vllm.v1.worker.block_table import BlockTable |
27 | 27 |
|
@@ -126,172 +126,6 @@ class LocalAttentionMetadata: |
126 | 126 |     local_attn_metadata: Optional[LocalAttentionMetadata] = None |
127 | 127 |
|
128 | 128 |
|
129 | | -# |
130 | | -# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into |
131 | | -# local attention blocks, where each block is passed to the attention kernel |
132 | | -# as an independent local ("virtual") batch item. |
133 | | -# |
134 | | -# For example, if we are performing chunked prefill on a batch of 3 sequences: |
135 | | -#   q_seqlens  = [4, 10, 5] |
136 | | -#   kv_seqlens = [6, 17, 9] |
137 | | -# Then normally for regular attention we would compute with an attention mask |
138 | | -# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: |
139 | | -#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) |
140 | | -#        k_toks >   0 1 2 3 4 5 |
141 | | -#        q_toks v  _____________ |
142 | | -#               0 | 1 1 1 |
143 | | -#               1 | 1 1 1 1 |
144 | | -#               2 | 1 1 1 1 1 |
145 | | -#               3 | 1 1 1 1 1 1 |
146 | | -# |
147 | | -# for local attention (with attn_chunk_size = 4) we would compute with an |
148 | | -#  attention mask like: |
149 | | -#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) |
150 | | -#        k_toks >   0 1 2 3 4 5 |
151 | | -#        q_toks v  _____________ |
152 | | -#               0 | 1 1 1 |
153 | | -#               1 | 1 1 1 1 |
154 | | -#               2 |         1 |
155 | | -#               3 |         1 1 |
156 | | -# |
157 | | -# We can simulate this mask using standard flash-attention by breaking the |
158 | | -# sequences into local ("virtual") batches, where each local batch item is a |
159 | | -# local attention block, so in this case batch idx 0 would be broken up into: |
160 | | -# |
161 | | -#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) |
162 | | -#        k_toks >   0 1 2 3 |
163 | | -#        q_toks v  _____________ |
164 | | -#               0 | 1 1 1 |
165 | | -#               1 | 1 1 1 1 |
166 | | -#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) |
167 | | -#        k_toks >   4 5 |
168 | | -#        q_toks v  _____________ |
169 | | -#               2 | 1 |
170 | | -#               3 | 1 1 |
171 | | -# |
172 | | -# e.g. if we have: |
173 | | -#   attn_chunk_size = 4 |
174 | | -#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) |
175 | | -# Then this function would return: |
176 | | -#                           __b0__  ______b1______  __b2__  < orig batch indices |
177 | | -#   seqlens_q_local    = [   2,  2,  1,  4,  4,  1,  4,  1] |
178 | | -#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19] |
179 | | -#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1] |
180 | | -#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch] |
181 | | -def make_local_attention_virtual_batches( |
182 | | -    attn_chunk_size: int, |
183 | | -    query_start_loc_np: np.ndarray, |
184 | | -    seq_lens_np: np.ndarray, |
185 | | -    block_table: torch.Tensor, |
186 | | -    block_size: int = 0, |
187 | | -) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: |
188 | | -    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] |
189 | | -    actual_batch_size = seq_lens_np.shape[0] |
190 | | - |
191 | | -    # Handle the case where we start in the middle of a local attention block; |
192 | | -    #  we assume q_seqlens > 0 (for all elements). For each batch idx we compute |
193 | | -    #  the number of query tokens that fall in the first (possibly partial) local |
194 | | -    #  attention block, and then we can simply use a cdiv for the rest. |
195 | | -    # For example, if we have: |
196 | | -    #   attn_chunk_size = 4 |
197 | | -    #   q_seqlens = [4, 10, 5] |
198 | | -    #   k_seqlens = [6, 17, 9] |
199 | | -    # then we would get: |
200 | | -    #   q_tokens_in_first_block = [2, 1, 4] |
201 | | -    #   local_blocks = [2, 4, 2] |
202 | | -    q_tokens_in_first_block = np.minimum( |
203 | | -        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), |
204 | | -        q_seqlens).astype(np.int32) |
205 | | -    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) |
206 | | -    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, |
207 | | -                            attn_chunk_size) |
208 | | - |
209 | | -    # Once we know the number of local blocks for each batch idx we can compute |
210 | | -    #  the request spans and figure out how many "virtual" requests we have to |
211 | | -    #  make. |
212 | | -    # For the above example we would get: |
213 | | -    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] |
214 | | -    # |
215 | | -    # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) |
216 | | -    #   (TODO: make a utility to share this code with _prepare_inputs) |
217 | | -    # arange step 1. [2, 4, 2] -> [2, 6, 8] |
218 | | -    cu_num_blocks = np.cumsum(local_blocks) |
219 | | -    virtual_batches = cu_num_blocks[-1] |
220 | | -    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] |
221 | | -    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) |
222 | | -    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] |
223 | | -    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets |
224 | | -    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) |
225 | | -    rarange = np.repeat(local_blocks, local_blocks) - arange - 1 |
226 | | -    # Then we can compute the seqlens_q_local, handling the fact that the |
227 | | -    #  first and last blocks could be partial |
228 | | -    seqlens_q_local = \
229 | | -        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) |
230 | | -    # set the first block since this may be a partial block |
231 | | -    seqlens_q_local[arange == 0] = q_tokens_in_first_block |
232 | | -    # set the remaining blocks |
233 | | -    seqlens_q_local[arange > 0] = np.minimum( |
234 | | -        seqlens_q_local - attn_chunk_size * (arange - 1), |
235 | | -        attn_chunk_size)[arange > 0] |
236 | | - |
237 | | -    # convert from q_seqlens to cu_seqlens_q |
238 | | -    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
239 | | -        .astype(np.int32) |
240 | | - |
241 | | -    # compute the seqlens_k_local, |
242 | | -    #  basically a full local attention block for all but the last block in each |
243 | | -    #  batch |
244 | | -    # For our example this will be: |
245 | | -    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] |
246 | | -    seqlens_k_local = np.full(cu_num_blocks[-1], |
247 | | -                              attn_chunk_size, |
248 | | -                              dtype=np.int32) |
249 | | -    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block |
250 | | - |
251 | | -    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
252 | | -        (rarange * attn_chunk_size + \
253 | | -            np.repeat(tokens_in_last_block, local_blocks)) |
254 | | -    # For the example the local attention blocks start at: |
255 | | -    #                           _b0_  _____b1_____  _b2_ |
256 | | -    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] |
257 | | -    block_starts = k_seqstarts_absolute // block_size |
258 | | -    assert attn_chunk_size % block_size == 0, \
259 | | -        f"attn_chunk_size {attn_chunk_size} is not " \
260 | | -        f"divisible by block_size {block_size}" |
261 | | -    pages_per_local_batch = attn_chunk_size // block_size |
262 | | - |
263 | | -    # Create a block_table for the local attention blocks |
264 | | -    # For our example, if we have a block-table like (assuming block_size=2): |
265 | | -    #   block_table = [ |
266 | | -    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0 |
267 | | -    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 |
268 | | -    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 |
269 | | -    #   ] |
270 | | -    # Then for the local batches we would want a block-table like |
271 | | -    #   block_table_local = [ |
272 | | -    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0]) |
273 | | -    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4]) |
274 | | -    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) |
275 | | -    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) |
276 | | -    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) |
277 | | -    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) |
278 | | -    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) |
279 | | -    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) |
280 | | -    #   ] |
281 | | -    block_indices = np.broadcast_to( |
282 | | -        np.arange(pages_per_local_batch, dtype=np.int32), |
283 | | -        (virtual_batches, pages_per_local_batch)) \
284 | | -        + np.expand_dims(block_starts, axis=1) |
285 | | -    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) |
286 | | -    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), |
287 | | -                              local_blocks * pages_per_local_batch) |
288 | | -    block_table_local = block_table[batch_indices, block_indices]\
289 | | -        .view(virtual_batches, -1) |
290 | | - |
291 | | -    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
292 | | -        block_table_local |
293 | | - |
294 | | - |
295 | 129 | def _get_sliding_window_configs( |
296 | 130 |         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: |
297 | 131 |     """Get the set of all sliding window configs used in the model.""" |
|
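For reference, a minimal usage sketch (not part of this commit) that exercises the relocated `make_local_attention_virtual_batches` on the worked example from the removed comments, assuming the helper keeps the signature shown above now that it is imported from `vllm.v1.attention.backends.utils`:

import numpy as np
import torch

from vllm.v1.attention.backends.utils import (
    make_local_attention_virtual_batches)

attn_chunk_size = 4
block_size = 2
# q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9]
query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)
seq_lens = np.array([6, 17, 9], dtype=np.int32)
# 3 requests x 10 pages each, matching the block-table example in the comments
block_table = torch.arange(30, dtype=torch.int32).view(3, 10)

(seqlens_q_local, cu_seqlens_q_local, seqlens_k_local,
 block_table_local) = make_local_attention_virtual_batches(
     attn_chunk_size, query_start_loc, seq_lens, block_table, block_size)

# Expected, per the example in the removed comments:
#   seqlens_q_local    = [2, 2, 1, 4, 4, 1, 4, 1]
#   cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
#   seqlens_k_local    = [4, 2, 4, 4, 4, 1, 4, 1]
#   block_table_local  = [[ 0,  1], [ 2,  3], [12, 13], [14, 15],
#                         [16, 17], [18, 19], [22, 23], [24, 25]]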
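A small NumPy-only sketch (illustrative; it mirrors the batched-arange steps commented in the removed code) of how the per-request local block counts are expanded into per-virtual-batch positions:

import numpy as np

# Local attention blocks per request for the example above
# (attn_chunk_size = 4, q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9]).
local_blocks = np.array([2, 4, 2])

# step 1: cumulative block counts             -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]           # 8 virtual batch items in total

# step 2: index of each request's first block -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)

# step 3: position of each virtual batch within its request
#                                             -> [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets

# reverse position (distance from the request's last block)
#                                             -> [1, 0, 3, 2, 1, 0, 1, 0]
rarange = np.repeat(local_blocks, local_blocks) - arange - 1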