@@ -105,6 +105,7 @@ class LocalAttentionMetadata:
105105 local_block_table : torch .Tensor
106106 local_max_query_len : int
107107 local_max_seq_len : int
108+ local_scheduler_metadata : Optional [torch .Tensor ]
108109
109110 local_attn_metadata : Optional [LocalAttentionMetadata ] = None
110111
@@ -282,7 +283,9 @@ def __init__(self, runner: "GPUModelRunner"):
282283
283284 self .runner = runner
284285 self .aot_schedule = (get_flash_attn_version () == 3 )
285- self .num_heads = model_config .get_num_attention_heads (
286+ self .num_heads_q = model_config .get_num_attention_heads (
287+ runner .parallel_config )
288+ self .num_heads_kv = model_config .get_num_kv_heads (
286289 runner .parallel_config )
287290 self .headdim = model_config .get_head_size ()
288291 self .page_size = self .runner .block_size
@@ -304,6 +307,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
304307 slot_mapping = self .runner .slot_mapping_cpu [:num_actual_tokens ].to (
305308 self .runner .device , non_blocking = True ).long ()
306309
310+ def schedule (batch_size , cu_query_lens , max_query_len , seqlens ,
311+ max_seq_len , causal ):  # helper: returns get_scheduler_metadata(...) for one attention call, or None
312+ if self .aot_schedule :  # set in __init__ to (get_flash_attn_version() == 3)
313+ return get_scheduler_metadata (
314+ batch_size = batch_size ,  # caller-supplied: num_reqs for the regular/cascade calls, local_query_start_loc.shape[0] - 1 for local attention
315+ max_seqlen_q = max_query_len ,
316+ max_seqlen_k = max_seq_len ,
317+ cache_seqlens = seqlens ,
318+ num_heads_q = self .num_heads_q ,  # from model_config.get_num_attention_heads(runner.parallel_config)
319+ num_heads_kv = self .num_heads_kv ,  # from model_config.get_num_kv_heads(runner.parallel_config)
320+ headdim = self .headdim ,  # from model_config.get_head_size()
321+ page_size = self .page_size ,  # self.runner.block_size
322+ cu_seqlens_q = cu_query_lens ,
323+ causal = causal ,
324+ )
325+ return None  # aot_schedule is false: no scheduler metadata is produced
326+
307327 # for local attention
308328 local_attn_metadata = None
309329 if self .runner .attention_chunk_size is not None :
@@ -315,36 +335,31 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
315335 block_table ,
316336 self .runner .block_size ,
317337 )
338+ local_query_start_loc = torch .from_numpy (virt_q_cu_seqlens_np ).to (
339+ self .runner .device , non_blocking = True )
340+ local_seqused_k = torch .from_numpy (virt_k_seqlens_np ).to (
341+ self .runner .device , non_blocking = True )
342+ local_max_query_len = seqlens_q_local_np .max ()
343+ local_max_seq_len = virt_k_seqlens_np .max ()
344+ local_scheduler_metadata = schedule (
345+ batch_size = local_query_start_loc .shape [0 ] - 1 ,
346+ cu_query_lens = local_query_start_loc ,
347+ max_query_len = local_max_query_len ,
348+ seqlens = local_seqused_k ,
349+ max_seq_len = local_max_seq_len ,
350+ causal = True )
351+
318352 local_attn_metadata = FlashAttentionMetadata .LocalAttentionMetadata (
319- local_query_start_loc = torch .from_numpy (
320- virt_q_cu_seqlens_np ).to (self .runner .device ,
321- non_blocking = True ),
322- local_seqused_k = torch .from_numpy (virt_k_seqlens_np ).to (
323- self .runner .device , non_blocking = True ),
353+ local_query_start_loc = local_query_start_loc ,
354+ local_seqused_k = local_seqused_k ,
324355 local_block_table = virt_block_table ,
325- local_max_query_len = seqlens_q_local_np .max (),
326- local_max_seq_len = virt_k_seqlens_np .max (),
356+ local_max_query_len = local_max_query_len ,
357+ local_max_seq_len = local_max_seq_len ,
358+ local_scheduler_metadata = local_scheduler_metadata ,
327359 )
328360
329361 use_cascade = common_prefix_len > 0
330362
331- def schedule (cu_query_lens , max_query_len , seqlens , max_seq_len ,
332- causal ):
333- if self .aot_schedule :
334- return get_scheduler_metadata (
335- batch_size = num_reqs ,
336- max_seqlen_q = max_query_len ,
337- max_seqlen_k = max_seq_len ,
338- cache_seqlens = seqlens ,
339- num_heads_q = self .num_heads ,
340- num_heads_kv = self .num_heads ,
341- headdim = self .headdim ,
342- page_size = self .page_size ,
343- cu_seqlens_q = cu_query_lens ,
344- causal = causal ,
345- )
346- return None
347-
348363 if use_cascade :
349364 cu_prefix_query_lens = torch .tensor ([0 , num_actual_tokens ],
350365 dtype = torch .int32 ,
@@ -357,12 +372,14 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
357372 suffix_kv_lens = torch .from_numpy (suffix_kv_lens ).to (
358373 self .runner .device )
359374 prefix_scheduler_metadata = schedule (
375+ batch_size = num_reqs ,
360376 cu_query_lens = cu_prefix_query_lens ,
361377 max_query_len = num_actual_tokens ,
362378 seqlens = prefix_kv_lens ,
363379 max_seq_len = common_prefix_len ,
364380 causal = False )
365- scheduler_metadata = schedule (cu_query_lens = query_start_loc ,
381+ scheduler_metadata = schedule (batch_size = num_reqs ,
382+ cu_query_lens = query_start_loc ,
366383 max_query_len = max_query_len ,
367384 seqlens = suffix_kv_lens ,
368385 max_seq_len = max_seq_len -
@@ -373,7 +390,8 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
373390 prefix_kv_lens = None
374391 suffix_kv_lens = None
375392 prefix_scheduler_metadata = None
376- scheduler_metadata = schedule (cu_query_lens = query_start_loc ,
393+ scheduler_metadata = schedule (batch_size = num_reqs ,
394+ cu_query_lens = query_start_loc ,
377395 max_query_len = max_query_len ,
378396 seqlens = seq_lens ,
379397 max_seq_len = max_seq_len ,
@@ -540,12 +558,14 @@ def forward(
540558 max_seqlen_q = local_metadata .local_max_query_len
541559 max_seqlen_k = local_metadata .local_max_seq_len
542560 block_table = local_metadata .local_block_table
561+ scheduler_metadata = local_metadata .local_scheduler_metadata
543562 else :
544563 cu_seqlens_q = attn_metadata .query_start_loc
545564 seqused_k = attn_metadata .seq_lens
546565 max_seqlen_q = attn_metadata .max_query_len
547566 max_seqlen_k = attn_metadata .max_seq_len
548567 block_table = attn_metadata .block_table
568+ scheduler_metadata = attn_metadata .scheduler_metadata
549569
550570 descale_shape = (cu_seqlens_q .shape [0 ] - 1 , key .shape [1 ])
551571
@@ -564,7 +584,7 @@ def forward(
564584 window_size = self .sliding_window ,
565585 block_table = block_table ,
566586 softcap = self .logits_soft_cap ,
567- scheduler_metadata = attn_metadata . scheduler_metadata ,
587+ scheduler_metadata = scheduler_metadata ,
568588 fa_version = self .vllm_flash_attn_version ,
569589 q_descale = layer ._q_scale .expand (descale_shape ),
570590 k_descale = layer ._k_scale .expand (descale_shape ),
0 commit comments