diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 93002012799a..1f4efb556442 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -159,6 +159,10 @@ class SchedulerConfig: structured outputs, speculative decoding, and pipeline parallelism. """ + split_prefill_from_chunk: bool = False + """Whether to split the prefill request into pure prefill and chunked prefill in a single + batch.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c6d14aa87c7f..65a816e68a95 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -340,6 +340,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA: + # enable the request reorder if we are using AITER MHA for calculation + vllm_config.scheduler_config.split_prefill_from_chunk = True + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 173a0a255e49..7e61f6b87e47 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,27 +2,192 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, ClassVar import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + split_decodes_prefills_and_chunk) +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 +_CHUNK_PREFILL_TOKENS_PER_ITER_ROCM = 32 * 1024 + +KV_CACHE_LAYOUT_V0 = False + if current_platform.is_rocm(): import aiter - from vllm.triton_utils import tl, triton + # from vllm.triton_utils import tl, triton + import triton + import triton.language as tl from vllm.utils import direct_register_custom_op + from aiter.ops.triton.utils.device_info import get_num_sms + + def block_size(x, head_dim): + return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) + + def num_programs(head_dim): + return min(head_dim, get_num_sms()) + + @triton.jit + def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, num_heads, head_size / x, page_size, x] or [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, num_heads, head_size, page_size] or [num_blocks, page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] note: max_cum_tokens should always larger or equal than max_tokens + seq_start_ptr, # [num_batches] + k_scale_ptr, + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + num_tokens, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: 
tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_PRGMS: tl.constexpr + ): + bid = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + + for token_id in tl.range(bid, num_tokens, NUM_PRGMS): + key_ptr_offset = key_ptr + token_id * head_size * num_heads + value_ptr_offset = value_ptr + token_id * head_size * num_heads + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "v0": + # For kv cache layout as + # K: [num_blocks, num_heads, head_size / x, page_size, x] + # V: [num_blocks, num_heads, head_size, page_size] + key_cache_ptr_offset = key_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * x + value_cache_ptr_offset = value_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id + # since the num_head and head_dim are not contiguous, we use two loop the iter over the data + for head in tl.range(0, num_heads): + src_head_offset = head * PAGE_SIZE * head_size + dst_head_offset = head * head_size + for i in tl.range(0, head_size, BLOCK_SIZE): + mask = (col_offsets + i) < head_size + k_offset = (col_offsets + i) // x * PAGE_SIZE * x + col_offsets % x + k_reg = tl.load(key_cache_ptr_offset + src_head_offset + k_offset, mask=mask) + v_offset = (col_offsets + i) * PAGE_SIZE + v_reg = tl.load(value_cache_ptr_offset + src_head_offset + v_offset, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + + k_reg = (k_reg.to(tl.float32) * v_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * k_scale).to(v_dtype) + + tl.store(key_ptr_offset + dst_head_offset + col_offsets, k_reg, mask=mask) + tl.store(value_ptr_offset + dst_head_offset + col_offsets, v_reg, mask=mask) + elif CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = key_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * num_heads * head_size + value_cache_ptr_offset = value_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * num_heads * head_size + for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): + mask = (col_offsets + i) < head_size * num_heads + k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask) + v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask) + tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask) + + + def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: float, + v_scales: float, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int + ): + assert kv_cache_layout in ["v0", "NHD", "HND"], "kv_cache_layout only support v0, NHD, HND" + head_dim = key.shape[2] + x = 0 + assert dequant is True, "Currently, we only support 
gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, head_dim / x, page_size, x] + if kv_cache_layout == "v0": + x = key_cache.shape[4] + num_heads = key.shape[1] + page_size = key_cache.shape[3] + assert x * key_cache.shape[2] == head_dim, "We assume your kv cache layout is [num_blocks, num_heads, head_dim/x, page_size, x], but got otherwise" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + elif kv_cache_layout == "HND": + assert False + assert head_dim == key_cache.shape[3], "We assume your kv cache layout is [num_blocks, num_heads, page_size, head_dim], but got otherwise" + page_size = key_cache.shape[2] + num_heads = key_cache.shape[1] + elif kv_cache_layout == "NHD": + assert head_dim == key_cache.shape[3], "We assume your kv cache layout is [num_blocks, page_size, num_heads, head_dim], but got otherwise" + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + else: + raise RuntimeError + + NUM_PRGMS = num_programs(total_tokens) + BLOCK_SIZE = block_size(key_cache, head_dim) + grid = lambda meta: (NUM_PRGMS, ) + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + total_tokens, + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS + ) + @triton.jit def _vllm_layout_trans_kernel( @@ -36,6 +201,7 @@ def _vllm_layout_trans_kernel( block_table_stride_0, k_scale, v_scale, + skip_query: tl.constexpr, output_dtype: tl.constexpr, E_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -43,13 +209,14 @@ def _vllm_layout_trans_kernel( batch_idx = tl.program_id(0) block_idx = tl.program_id(1) - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) - batch_query_start, batch_query_end = tl.split(batch_query_indexes) - query_len = batch_query_end - batch_query_start + if skip_query: + batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + + tl.arange(0, 2)) + batch_query_start, batch_query_end = tl.split(batch_query_indexes) + query_len = batch_query_end - batch_query_start - if query_len <= 1: - return + if query_len <= 1: + return batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) @@ -116,6 +283,9 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, output_dtype = tl.bfloat16 else: raise ValueError(f"Unsupported output dtype: {output_dtype}") + skip_query = False + if b_query_lens_loc is None: + skip_query = True _vllm_layout_trans_kernel[grid](k_cache, v_cache, @@ -128,6 +298,7 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, k_scale, v_scale, output_dtype=output_dtype, + skip_query=skip_query, E_DIM=H_KV * D, BLOCK_SIZE=BLOCK_SIZE) @@ -201,9 +372,43 @@ def flash_attn_varlen_func_fake( flash_attn_varlen_func_fake, dispatch_key=current_platform.dispatch_key) -logger = init_logger(__name__) +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor 
+ num_chunks: int + total_token_per_batch: list[int] + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata + @dataclass class AiterFlashAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -222,7 +427,18 @@ class AiterFlashAttentionMetadata: seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor - cu_seq_lens: Optional[torch.Tensor] + + # prefill and deocde split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_chunk_prefills: int + num_chunk_prefill_tokens: int + + decode_metadata: Optional[AiterFlashAttentionDecodeMetadata] + pure_prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] + chunk_prefill_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] # For cascade attention. use_cascade: bool @@ -232,7 +448,9 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - cudagraph_support = AttentionCGSupport.ALWAYS + cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + reorder_batch_threshold: ClassVar[int] = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -254,6 +472,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 + self.chunk_prefill_workspace_size = _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM * self.num_heads_kv * self.headdim + + self.chunk_prefill_workspace = torch.empty( + [2, _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim], + dtype=self.model_config.dtype, + device=device + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata): self.total_tokens = self.model_config.max_model_len \ @@ -268,44 +494,108 @@ def build(self, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - if max_query_len > 1: - # We pre-compute cumulative seq len needed for prefill attention - # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, - dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) - num_actual_kv_tokens = int(cu_seq_lens[-1].item()) - else: - cu_seq_lens = None - num_actual_kv_tokens = 0 + split_ret = \ + split_decodes_prefills_and_chunk(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + num_decodes, num_chunk_prefills, num_pure_prefills, num_decode_tokens, num_chunk_prefill_tokens, num_pure_prefill_tokens = split_ret - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): - return None + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + seq_lens = common_attn_metadata.seq_lens_cpu + + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + decode_metadata = None 
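+        # NOTE (illustrative, not part of the original change): with
+        # decode_threshold=1, split_decodes_prefills_and_chunk assumes the
+        # batch is already ordered [decodes : chunked prefills : pure prefills]
+        # and returns the sizes of the three zones. Example with assumed
+        # values query_lens=[1, 1, 4, 8] and seq_lens=[10, 7, 20, 8]:
+        #   - requests 0 and 1 are decodes (query_len <= threshold)
+        #   - request 2 is a chunked prefill (seq_len 20 > query_len 4, i.e.
+        #     it still has already-computed context in the KV cache)
+        #   - request 3 is a pure prefill (seq_len == query_len)
+        # giving (num_decodes, num_chunk_prefills, num_pure_prefills,
+        #         num_decode_tokens, num_chunk_prefill_tokens,
+        #         num_pure_prefill_tokens) = (2, 1, 1, 2, 4, 8).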
+ if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + min_query_len=query_lens_cpu[:num_decodes].min().item(), + max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata.query_start_loc[:num_decodes + 1] + ) + + pure_prefill_metadata = None + if num_pure_prefills > 0: + query_lens_for_pure_prefill = query_lens_cpu[num_decodes + num_chunk_prefills:] + query_start_loc_device = common_attn_metadata.query_start_loc[num_decodes + num_chunk_prefills:] + pure_prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_pure_prefill.max().item(), + min_query_len=query_lens_for_pure_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + num_chunk_prefills:].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0] + ) + + chunk_prefill_metadata = None + if num_chunk_prefills > 0: + query_lens_for_chunk_prefill = query_lens_cpu[num_decodes:num_decodes + num_chunk_prefills] + seq_lens_for_chunk_prefill = common_attn_metadata.seq_lens_cpu[num_decodes: num_decodes + num_chunk_prefills] + computed_kv_lens = seq_lens_for_chunk_prefill - query_lens_for_chunk_prefill + + # allocate the equal amount of workspace for each chunk prefill request + max_context_chunk = (_CHUNK_PREFILL_TOKENS_PER_ITER_ROCM // num_chunk_prefills) + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + + + chunk_starts = torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, num_chunk_prefills) * max_context_chunk + chunk_ends = torch.min(computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) # [num_chunks, num_chunk_prefills] + cu_seq_lens_cpu = torch.zeros([num_chunks, num_chunk_prefills + 1], dtype=torch.int32, pin_memory=True) + torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + + range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :] # [num_chunks, num_chunk_prefills, max_cum_tokens] + idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None] # [num_chunks, num_chunk_prefills, max_cum_tokens] + idx_to_batch_tensor = idx_to_batch_tensor.sum(dim=1) # [num_chunks, max_cum_tokens] + token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1) + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.chunk_prefill_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist() + ) + + query_start_loc_device = common_attn_metadata.query_start_loc[num_decodes:num_decodes + num_chunk_prefills + 1] + seq_lens_device = common_attn_metadata.seq_lens[num_decodes:num_decodes + num_chunk_prefills] + cu_seq_lens = torch.zeros(num_chunk_prefills + 1, dtype=torch.int32, device=seq_lens_device.device) + torch.cumsum(seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) + chunk_prefill_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_chunk_prefill.max().item(), + min_query_len=query_lens_for_chunk_prefill.min().item(), + 
max_seq_len=seq_lens[num_decodes:num_decodes + num_chunk_prefills].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + chunk_context_metadata=chunk_context_metadata + ) + + num_actual_kv_tokens = torch.sum(seq_lens).item() use_cascade = common_prefix_len > 0 attn_metadata = AiterFlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - cu_seq_lens=cu_seq_lens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_pure_prefills, + num_prefill_tokens=num_pure_prefill_tokens, + num_chunk_prefills=num_chunk_prefills, + num_chunk_prefill_tokens=num_chunk_prefill_tokens, + decode_metadata=decode_metadata, + pure_prefill_metadata=pure_prefill_metadata, + chunk_prefill_metadata=chunk_prefill_metadata, use_cascade=use_cascade, common_prefix_len=common_prefix_len, total_tokens=self.total_tokens, @@ -364,7 +654,10 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + if KV_CACHE_LAYOUT_V0: + return (2, num_blocks, num_kv_heads, block_size, head_size) + else: + return (2, num_blocks, block_size, num_kv_heads, head_size) class AiterFlashAttentionImpl(AttentionImpl): @@ -411,6 +704,113 @@ def __init__( "are not implemented for " "FlashAttentionImpl") + + def chunk_prefill_forward( + self, + attn_metadata: AiterFlashAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: float, + v_scale: float, + ): + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True + ) + chunk_context_metadata = attn_metadata.chunk_prefill_metadata.chunk_context_metadata + seq_lens = chunk_context_metadata.seq_lens + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, value_fetched= workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + for chunk_idx in range(num_chunks): + + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + 
k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=True, + kv_cache_layout="v0" if KV_CACHE_LAYOUT_V0 else "NHD", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True + ) + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + + def forward( self, layer: torch.nn.Module, @@ -430,7 +830,10 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [2, num_blocks, block_size, num_kv_heads, head_size] + [2, num_blocks, block_size * num_kv_heads * head_size] + more specifically: + k_cache = [num_blocks, num_kv_heads, head_dim / x, block_size, x] + v_cache = [num_blocks, num_kv_heads, block_size, head_dim] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -449,6 +852,7 @@ def forward( # Profiling run. return output + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -457,7 +861,6 @@ def forward( # Minimize the PyTorch ops in this method as much as possible. # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: @@ -468,82 +871,168 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
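+            # Descriptive note (added for clarity): KV_CACHE_LAYOUT_V0 selects
+            # between two cache layouts. When it is True, the cache returned by
+            # get_kv_cache_shape is (num_blocks, num_kv_heads, block_size,
+            # head_size) and is re-viewed below as the legacy paged layout
+            #   K: [num_blocks, num_heads, head_size // x, block_size, x]
+            #   V: [num_blocks, num_heads, head_size, block_size]
+            # with x = 16 // itemsize (x = 8 for fp16/bf16), written via
+            # reshape_and_cache. Otherwise the flash "NHD" layout
+            # [num_blocks, block_size, num_kv_heads, head_size] is kept and
+            # reshape_and_cache_flash is used, matching the layout that
+            # cp_mha_gather_cache expects for kv_cache_layout="NHD".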
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + + if KV_CACHE_LAYOUT_V0: + num_blocks = key_cache.shape[0] + num_heads = key_cache.shape[1] + block_size = key_cache.shape[2] + head_size = key.shape[2] + x = 16 // key_cache.dtype.itemsize + + key_cache = key_cache.view([num_blocks, num_heads, head_size // x, block_size, x]) + value_cache = value_cache.view([num_blocks, num_heads, head_size, block_size]) + torch.ops._C_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(torch.float8_e4m3fnuz) value_cache = value_cache.view(torch.float8_e4m3fnuz) + # decode:chunk_prefill:pure_prefill + query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] + + output_actual_tokens = output[:num_actual_tokens] + + block_table = attn_metadata.block_table + num_decodes = attn_metadata.num_decodes + num_pure_prefills = attn_metadata.num_prefills + num_chunk_prefills = attn_metadata.num_chunk_prefills + + num_pure_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_chunk_prefill_tokens = attn_metadata.num_chunk_prefill_tokens if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if max_seqlen_q > 1: - torch.ops.vllm.flash_attn_varlen_func( - query[:num_actual_tokens], - key_cache, - value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + + # calculate for pure prefills + if num_pure_prefills > 0: + + prefill_query = query[num_decode_tokens + num_chunk_prefill_tokens:] + prefill_key = key[num_decode_tokens + num_chunk_prefill_tokens:] + prefill_value = value[num_decode_tokens + num_chunk_prefill_tokens:] + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.pure_prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.pure_prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.pure_prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.pure_prefill_metadata.max_seq_len, + min_seqlen_q=attn_metadata.pure_prefill_metadata.min_query_len, + dropout_p=0.0, softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, + causal=True, window_size=self.sliding_window, - block_table=block_table, - cu_seqlens_k=attn_metadata.cu_seq_lens, + alibi_slopes=self.alibi_slopes, + out=output_actual_tokens[num_decode_tokens + num_chunk_prefill_tokens:], + ) + + # calculate for chunk prefills + if num_chunk_prefills > 0: + chunk_prefill_querys = query[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_keys = key[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_values = value[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_outputs = output[num_decode_tokens:num_decode_tokens + 
num_chunk_prefill_tokens] + self.chunk_prefill_forward( + attn_metadata=attn_metadata, + query=chunk_prefill_querys, + key=chunk_prefill_keys, + value=chunk_prefill_values, + key_cache=key_cache, + value_cache=value_cache, + output=chunk_prefill_outputs, + cu_seqlens_q=attn_metadata.chunk_prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.chunk_prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.chunk_prefill_metadata.max_seq_len, + min_seqlen_q=attn_metadata.chunk_prefill_metadata.min_query_len, + block_table=attn_metadata.block_table[num_decodes:num_decodes + num_chunk_prefills], + slot_mapping=attn_metadata.slot_mapping[num_decodes:num_decodes + num_chunk_prefills], k_scale=layer._k_scale, v_scale=layer._v_scale, - total_tokens=attn_metadata.num_actual_kv_tokens, ) - _, num_heads, head_size = query.shape - nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM - - workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, - dtype=torch.uint8, - device=output.device, - ) - - torch.ops.aiter.paged_attention_v1( - output[:num_actual_tokens], - workspace_buffer, - query[:num_actual_tokens], - key_cache, - value_cache, - self.scale, - block_table, - cu_seqlens_q, - seqused_k, - max_seqlen_k, - self.alibi_slopes, - self.kv_cache_dtype, - "NHD", - self.logits_soft_cap, - layer._k_scale, - layer._v_scale, - None, - _PARTITION_SIZE_ROCM, - ) - return output + # calculate for decodes + if num_decodes > 0: + if KV_CACHE_LAYOUT_V0: + # ============= spec decode ================= + # kv cache layout: [num_blocks, num_heads, head_dim / x, page_size, x] + from aiter.paged_attn import PagedAttention + # for spec decode impl + decode_output = PagedAttention.forward_decode( + query[:num_decode_tokens], + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_table[:num_decode_tokens], + seq_lens=attn_metadata.seq_lens[:num_decodes], + max_seq_len=attn_metadata.decode_metadata.max_seq_len, + kv_cache_dtype=self.kv_cache_dtype, + num_kv_heads=self.num_kv_heads, + scale=self.scale, + alibi_slopes=self.alibi_slopes, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + mtp=attn_metadata.decode_metadata.max_query_len + ) + output_actual_tokens[:num_decode_tokens] = decode_output + # ============= spec decode ================= + else: + _, num_heads, head_size = query.shape + nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + max_num_partitions = (attn_metadata.decode_metadata.max_seq_len + _PARTITION_SIZE_ROCM - + 1) // _PARTITION_SIZE_ROCM + + workspace_buffer = torch.empty( + (num_decode_tokens * num_heads * max_num_partitions * head_size) * + nbytes_per_qo_elem + 2 * + (num_decode_tokens * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=output.device, + ) + + torch.ops.aiter.paged_attention_v1( + output_actual_tokens[:num_decode_tokens], + workspace_buffer, + query[:num_decode_tokens], + key_cache, + value_cache, + self.scale, + attn_metadata.block_table[:num_decodes], + attn_metadata.decode_metadata.query_start_loc, + attn_metadata.seq_lens[:num_decodes], + attn_metadata.decode_metadata.max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + "NHD", + self.logits_soft_cap, + layer._k_scale, + layer._v_scale, + None, + _PARTITION_SIZE_ROCM, + ) else: raise NotImplementedError( "Cascade attention is not implemented 
for ROCM AITER") + + return output + diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b286a4ba9fe5..af01a040ecc9 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, fields, make_dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, TypeVar) +from collections import deque import numpy as np import torch @@ -642,6 +643,67 @@ def subclass_attention_backend( {"get_builder_cls": lambda: builder_cls}) +def split_decodes_prefills_and_chunk( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: CommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. + """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, 0, num_tokens, 0, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, 0, num_tokens, 0, 0 + + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold), f"got query lens: {query_lens[first_prefill:]} and decode threshold {decode_threshold}" + assert torch.all(query_lens[:first_prefill] <= decode_threshold), f"got query lens: {query_lens[:first_prefill]} and decode threshold {decode_threshold}" + num_decodes = first_prefill + num_decode_tokens = query_start_loc[first_prefill].item() + + query_lens_prefill = query_lens[first_prefill:] + seq_lens_prefill = seq_lens[first_prefill:] + is_pure_prefill = seq_lens_prefill == query_lens_prefill + + if torch.all(is_pure_prefill): + num_pure_prefills = num_reqs - num_decodes + num_pure_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, 0, num_pure_prefills, num_decode_tokens, 0, num_pure_prefill_tokens) + + num_prefills = num_reqs - num_decodes + num_prefill_tokens = num_tokens - num_decode_tokens + first_chunk_prefill = is_pure_prefill.int().argmax(dim=-1).item() + + num_chunk_prefills = first_chunk_prefill + num_pure_prefills = num_prefills - first_chunk_prefill + + num_chunk_prefill_tokens = query_start_loc[num_chunk_prefills + num_decodes].item() - num_decode_tokens + num_pure_prefill_tokens = num_tokens - num_decode_tokens - num_chunk_prefill_tokens + return (num_decodes, num_chunk_prefills, num_pure_prefills, num_decode_tokens, num_chunk_prefill_tokens, num_pure_prefill_tokens) + + def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, @@ -684,10 +746,138 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def reorder_batch_to_split_decodes_prefills_and_chunks( + input_batch: "InputBatch", + scheduler_output: 
"SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill, chunk_prefill and decode requests; places all + requests in the order of [decodes:chunked_prefills:pure_prefills]. + + Returns: + True if the batch was modified, False otherwise. + """ + + # We assume most of the request is already in the order of what we desired since this function + # should only be opened after the `SchedulerConfig.split_prefill_from_chunk` is True. So we only + # need to spot all mismatched request and swap their positions for efficiency. + + decodes = [] + prefills = [] + chunk_prefills = [] + + def print_order_of_batch(): + new_decode = [] + new_chunk_prefill = [] + new_prefills = [] + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= decode_threshold: + new_decode.append(i) + elif input_batch.num_computed_tokens_cpu[i] > 0: + # print("found one chunk prefill request, computed token is: ", input_batch.num_computed_tokens_cpu[i]) + new_chunk_prefill.append(i) + else: + new_prefills.append(i) + print("decodes: ", new_decode) + print("append: ", new_chunk_prefill) + print("prefills: ", new_prefills) + + # print("into split d p c") + # print('before reorder') + # print_order_of_batch() + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= decode_threshold: + decodes.append(i) + elif input_batch.num_computed_tokens_cpu[i] > 0: + # print("found one chunk prefill request, computed token is: ", input_batch.num_computed_tokens_cpu[i]) + chunk_prefills.append(i) + else: + prefills.append(i) + + num_decodes = len(decodes) + num_chunk_prefills = len(chunk_prefills) + # We define the reorder matrix here to help on the request reorder + # reorder_matrix[(i, j)] means the id the the requests that suppose to be in + # zone i but actually spot on zone j + # The decode, chunk prefill and pure prefill are separated into 3 different zone + # here, 0 for decode, 1 for chunk prefill and 2 for pure prefill + reorder_matrix = {(i, j): deque() for i in range(3) for j in range(3) if i!=j} + # collect mismatch + + def target_idx(idx): + if idx < num_decodes: + # decode as zone 0 + return 0 + elif idx < num_decodes + num_chunk_prefills: + # chunk prefill as zone 1 + return 1 + else: + # pure prefill as zone 2 + return 2 + + def fill_reorder_matrix(request_lists, reorder_sequence): + for idx, seq in enumerate(reorder_sequence): + request_list = request_lists[idx] + for req_idx in request_list: + req_target_id = target_idx(req_idx) + if seq != req_target_id: + reorder_matrix[(seq, req_target_id)].append(req_idx) + # print("reorder matrix: ", reorder_matrix) + + def direct_zone_swap(i, j): + assert i != j + modified_batch = False + while reorder_matrix[(i, j)] and reorder_matrix[(j, i)]: + swap_req1 = reorder_matrix[(i, j)].pop() + swap_req2 = reorder_matrix[(j, i)].pop() + input_batch.swap_states(swap_req1, swap_req2) + modified_batch = True + + return modified_batch + + # in order 1,2,3, out order 3, 1, 2 + def indirect_zone_swap(zone_list): + assert len(zone_list) == 3 + modified_batch = False + while reorder_matrix[zone_list[0]] and reorder_matrix[zone_list[1]] and reorder_matrix[zone_list[2]]: + swap_req1 = reorder_matrix[zone_list[0]].pop() + swap_req2 = reorder_matrix[zone_list[1]].pop() + swap_req3 = reorder_matrix[zone_list[2]].pop() + # print("do indirect swap: ", swap_req1, swap_req2, swap_req3) + # print("desired order 
should be : ", swap_req3, swap_req1, swap_req2) + + input_batch.swap_states(swap_req1, swap_req2) + input_batch.swap_states(swap_req2, swap_req3) + modified_batch = True + return modified_batch + + + fill_reorder_matrix([decodes, chunk_prefills, prefills], [0, 1, 2]) + + modified_batch = False + # do directly swap for + modified_batch &= direct_zone_swap(0, 1) # decode <--> chunk prefill + modified_batch &= direct_zone_swap(0, 2) # decode <--> pure prefill + modified_batch &= direct_zone_swap(1, 2) # chunk prefill <--> pure prefill + + modified_batch &= indirect_zone_swap(((0, 1), (1, 2), (2, 0))) + modified_batch &= indirect_zone_swap(((2, 1), (0, 2), (1, 0))) + + # print("after reorder") + # print_order_of_batch() + + return modified_batch + + + def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", decode_threshold: int = 1, + reorder_append_prefills: bool = False, ) -> bool: """ Reorders the batch to split into prefill and decode requests; places all diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d40e96632c9..5becc44a3407 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -187,6 +187,8 @@ def schedule(self) -> SchedulerOutput: # and the "jump decoding" optimization in the future. scheduled_new_reqs: list[Request] = [] + new_reqs_for_pure_preill: list[Request] = [] + new_reqs_for_chunk_prefill: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] @@ -498,18 +500,23 @@ def schedule(self) -> SchedulerOutput: request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + if num_computed_tokens > 0 and self.scheduler_config.split_prefill_from_chunk: + new_reqs_for_chunk_prefill.append(request) + else: + new_reqs_for_pure_preill.append(request) + req_index += 1 - self.running.append(request) + # self.running.append(request) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) - if request.status == RequestStatus.WAITING: - scheduled_new_reqs.append(request) - elif request.status == RequestStatus.PREEMPTED: - scheduled_resumed_reqs.append(request) - else: - raise RuntimeError( - f"Invalid request status: {request.status}") + # if request.status == RequestStatus.WAITING: + # scheduled_new_reqs.append(request) + # elif request.status == RequestStatus.PREEMPTED: + # scheduled_resumed_reqs.append(request) + # else: + # raise RuntimeError( + # f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) @@ -517,7 +524,7 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens - request.status = RequestStatus.RUNNING + # request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens # Count the number of prefix cached tokens. if request.num_cached_tokens < 0: @@ -531,6 +538,22 @@ def schedule(self) -> SchedulerOutput: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + # reorder the request during scheduling, put chunked prefill at the top of + # the scheduled_new_reqs to make sure the actual reorder in model runner + # happens as less as possible. 
+ new_reqs_for_chunk_prefill.extend(new_reqs_for_pure_preill) + for req in new_reqs_for_chunk_prefill: + self.running.append(req) + + if req.status == RequestStatus.WAITING: + scheduled_new_reqs.append(req) + elif req.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(req) + else: + raise RuntimeError( + f"Invalid request status: {req.status}") + req.status = RequestStatus.RUNNING + # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 897c3a621320..971b4be09f13 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -59,7 +59,8 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills) + reorder_batch_to_split_decodes_and_prefills, + reorder_batch_to_split_decodes_prefills_and_chunks) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -443,10 +444,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: if self.dcp_world_size > 1: assert self.reorder_batch_threshold == 1, \ "DCP not support reorder_batch_threshold > 1 now." - reorder_batch_to_split_decodes_and_prefills( - self.input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold) + if self.scheduler_config.split_prefill_from_chunk: + reorder_batch_to_split_decodes_prefills_and_chunks( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + else: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) # Note: used for model runner override. def _init_device_properties(self) -> None: @@ -542,7 +549,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) - reqs_to_add.append(req_state) # Update the states of the running/resumed requests. @@ -591,6 +597,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + reqs_to_add.append(req_state) continue @@ -629,6 +636,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests
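Note on the chunked-context math in chunk_prefill_forward: each context chunk produces a partial attention output plus its log-sum-exp (LSE), and merge_attn_states folds them into a running result before the causal self-attention over the new query tokens is merged in last. The sketch below is a minimal, unoptimized reference of that LSE merge; it assumes outputs of shape [num_tokens, num_heads, head_dim] and LSEs of shape [num_heads, num_tokens], which may differ from the exact tensors returned by aiter.flash_attn_varlen_func.

    import torch

    def merge_partial_attn(o_a: torch.Tensor, lse_a: torch.Tensor,
                           o_b: torch.Tensor, lse_b: torch.Tensor):
        # o_*:   [num_tokens, num_heads, head_dim] partial attention outputs
        # lse_*: [num_heads, num_tokens] log-sum-exp of the attention logits
        lse_max = torch.maximum(lse_a, lse_b)
        w_a = torch.exp(lse_a - lse_max)
        w_b = torch.exp(lse_b - lse_max)
        denom = w_a + w_b
        merged_lse = lse_max + torch.log(denom)
        # normalize and broadcast the weights over head_dim
        w_a = (w_a / denom).transpose(0, 1).unsqueeze(-1)
        w_b = (w_b / denom).transpose(0, 1).unsqueeze(-1)
        return w_a * o_a + w_b * o_b, merged_lse

Because chunks are merged one at a time, the gathered key/value workspace never has to hold more than _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM (32 * 1024) tokens per iteration, regardless of how long the already-computed context is.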