diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index cb2037b575e5..b02ece3cae10 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -57,6 +57,7 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions. The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) + ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") ### FusedMoEPermuteExpertsUnpermute diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 406669741914..1664af684a00 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -90,7 +90,13 @@ def parse_args(): parser.add_argument( "--enable-microbatching", action="store_true", - help=("Enable microbatched execution"), + help=("Enable microbatched execution") + ) + parser.add_argument( + "--compilation-config", + type=int, + default=0, + help=("Compilation optimization (O) level 0-3."), ) parser.add_argument( "--compilation-config", diff --git a/vllm/compilation/ubatch_wrapper.py b/vllm/compilation/ubatch_wrapper.py index 3e348c890800..c1e2fd548454 100644 --- a/vllm/compilation/ubatch_wrapper.py +++ b/vllm/compilation/ubatch_wrapper.py @@ -15,7 +15,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts +from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts, Schedule logger = init_logger(__name__) @@ -40,13 +40,15 @@ class CUDAGraphMetaData: class UBatchWrapper: def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, device: torch.cuda.device): + runtime_mode: CUDAGraphMode, device: torch.cuda.device, + delayed_start: bool = False): + self.runnable = runnable self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.comm_stream = torch.cuda.Stream(device=device) self.ready_barrier = threading.Barrier(3) - + self.delayed_start = delayed_start self.cudagraphs: dict[int, CUDAGraphMetaData] = {} self.cudagraph_wrapper = None @@ -185,7 +187,8 @@ def _ubatch_thread(results, model, ubatch_metadata): def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, positions, inputs_embeds, intermediate_tensors, compute_stream, dp_metadata, batch_descriptor, - cudagraph_runtime_mode) -> list[UbatchMetadata]: + cudagraph_runtime_mode, + delayed_start: bool = False) -> list[UbatchMetadata]: # Create one forward context per ubatch 
forward_contexts = [] @@ -198,12 +201,20 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode)) + # Map CLI/config schedule string to Schedule enum + schedule_str = self.vllm_config.parallel_config.microbatch_schedule + schedule = Schedule.MLP_OVERLAP + if schedule_str == Schedule.ATTN_SHARED_OVERLAP.value: + schedule = Schedule.ATTN_SHARED_OVERLAP + ubatch_ctxs = make_ubatch_contexts( num_micro_batches=len(ubatch_slices), comm_stream=self.comm_stream, compute_stream=compute_stream, forward_contexts=forward_contexts, - ready_barrier=self.ready_barrier) + ready_barrier=self.ready_barrier, + delayed_start=delayed_start) + ubatch_metadata: list[UbatchMetadata] = [] for i, ubatch_slice in enumerate(ubatch_slices): diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index d467d47e5d81..6ac162e892f7 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -135,6 +135,13 @@ class ParallelConfig: request is greater than this threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" + microbatch_schedule: Literal["mlp_overlap", "ATTN_SHARED_OVERLAP"] = "mlp_overlap" + """Schedule policy for microbatch overlap coordination. + + - "mlp_overlap": overlap MLP compute and communication across ubatches + - "ATTN_SHARED_OVERLAP": overlap MLA attention and communication across ubatches + """ + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9fe1bcbd77cb..b837e6bba561 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -317,6 +317,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_microbatching: bool = ParallelConfig.enable_microbatching microbatching_token_threshold: int = ParallelConfig.microbatching_token_threshold + microbatch_schedule: str = ParallelConfig.microbatch_schedule eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb num_redundant_experts: int = EPLBConfig.num_redundant_experts @@ -682,6 +683,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["enable_microbatching"]) parallel_group.add_argument("--microbatching-token-threshold", **parallel_kwargs["microbatching_token_threshold"]) + parallel_group.add_argument( + "--microbatch-schedule", + dest="microbatch_schedule", + **parallel_kwargs["microbatch_schedule"]) + parallel_group.add_argument("--enable-async-comms", + **parallel_kwargs["enable_async_comms"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", @@ -1304,6 +1311,7 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, enable_microbatching=self.enable_microbatching, microbatching_token_threshold=self.microbatching_token_threshold, + microbatch_schedule=self.microbatch_schedule, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, max_parallel_loading_workers=self.max_parallel_loading_workers, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 2bbe523b4bf9..2b702815393d 100644 --- 
a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -58,31 +58,63 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_dispatch( + def _create_prepare_ops( self, - tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, - num_experts: int, + a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> Callable: + ) -> mk.PrepareResultType: + + # Apply router weights on input if requested (only supports topk=1) + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + # Quantize prior to dispatch for block-quantized path, otherwise defer + if quant_config.is_block_quantized: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + a1_post_scale = None + else: + a1q = a1 + a1q_scale = None + a1_post_scale = a1_scale - has_scales = token_scales is not None + # Inline dispatch (sync send+recv) + has_scales = a1q_scale is not None (num_tokens_per_rank, num_tokens_per_rdma_rank, dispatch_expert_num_tokens, is_token_in_rank, event) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, + topk_idx=topk_ids, num_experts=num_experts, previous_event=None, async_finish=False, allocate_on_comm_stream=False) - token_data = tokens + token_data: Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]] + token_data = a1q if has_scales: - token_data = (tokens, token_scales) + token_data = (a1q, a1q_scale) + + ######################################################################## + yield # Pre-dispatch done + ######################################################################## ( token_data, expert_topk_ids, expert_topk_weights, @@ -94,10 +126,8 @@ def _do_dispatch( num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, num_tokens_per_expert=dispatch_expert_num_tokens, - topk_idx=rank_topk_ids, - topk_weights=rank_topk_weights, - # expert_alignment rounds the number of tokens per expert - # to this value. 
+ topk_idx=topk_ids, + topk_weights=topk_weights, expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, @@ -131,9 +161,12 @@ def _receiver( if self.async_prepare: event.current_stream_wait() + # Unpack token data if has_scales: + assert isinstance(token_data, tuple) expert_x, expert_x_scale = token_data else: + assert isinstance(token_data, torch.Tensor) expert_x, expert_x_scale = token_data, None # The existing MOE kernels assume that all entries of topk_ids are @@ -174,58 +207,14 @@ def _receiver( per_act_token_quant=False, block_shape=quant_config.block_shape) + ######################################################################## + yield # Dispatch send+recv done (sync) + ######################################################################## + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def supports_async(self) -> bool: - return True - - def prepare_async( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> Callable: - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * topk_weights.to(a1.dtype) - - if quant_config.is_block_quantized: - # Quant and Dispatch - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - if a1q_scale is not None and a1q_scale.numel() == 1: - a1q_scale = a1q_scale.view(1, 1) - a1_post_scale = None - else: - a1q = a1 - a1q_scale = None - a1_post_scale = a1_scale - - return self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config) - - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -236,14 +225,14 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) - return receiver() - - def finalize( + ) -> mk.SyncPrepareOps: + return mk.SyncPrepareOps.from_generator( + self._create_prepare_ops(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config)) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -268,6 +257,10 @@ def finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) + ######################################################################## + yield # Pre-combine done + ######################################################################## + combined_x, _, event = self.buffer.combine( x=fused_expert_output, handle=self.handle, @@ -278,3 +271,24 @@ def finalize( allocate_on_comm_stream=False) # Respect inplace outputs. 
output.copy_(combined_x, non_blocking=True) + + ######################################################################## + yield # Combine send-recv done + ######################################################################## + + return None + + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.SyncFinalizeOps: + return mk.SyncFinalizeOps.from_generator( + self._create_finalize_ops(output, fused_expert_output, + topk_weights, topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl)) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 76b71cd415d4..926dade1147f 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from typing import Optional, Union, Callable import deep_ep import torch @@ -11,11 +11,7 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) -from vllm.v1.worker.ubatching import (dbo_enabled, - dbo_current_ubatch_id, - dbo_yield, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook) +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -117,7 +113,7 @@ def _do_quant( def supports_async(self) -> bool: return True - def prepare_async( + def _create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -128,10 +124,9 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[Callable, mk.ReceiverType]: - - hidden_size = a1.size(1) + ) -> mk.PrepareResultType: a2a_idx = dbo_current_ubatch_id() + hidden_size = a1.size(1) if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -149,10 +144,14 @@ def prepare_async( assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) + + ######################################################################## + yield # Pre-dispatch done + ######################################################################## # Dispatch - expert_x, expert_num_tokens, handle, _, hook= \ - self.buffer.low_latency_dispatch(a1, + _expert_x, expert_num_tokens, handle, _, recv_hook= \ + self.buffers[a2a_idx].low_latency_dispatch(a1, topk_ids, self.max_tokens_per_rank, num_experts, @@ -160,32 +159,28 @@ def prepare_async( async_finish=False, return_recv_hook=True) self.handles[a2a_idx] = handle - - return (hook, lambda hook: self._receiver(hook, expert_x, expert_num_tokens, - a1_scale, a1.dtype, quant_config)) - - def _receiver( - self, - hook: Optional[Callable], - expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - expert_num_tokens: torch.Tensor, - a1_scale, - a1_dtype, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - if hook is not None: - hook() + + ######################################################################## + yield # Dispatch send done + 
######################################################################## + + recv_hook() + + ######################################################################## + yield # Dispatch recv done + ######################################################################## expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, + _expert_x, a1_scale, a1.dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) return expert_x, expert_x_scale, expert_tokens_meta, None, None - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -196,14 +191,20 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) - return receiver() - - def finalize( + ) -> mk.AsyncPrepareOps: + return mk.AsyncPrepareOps.from_generator( + self._create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config)) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -217,7 +218,6 @@ def finalize( ), ("Weight application and reduction happens in the combine kernel.") a2a_idx = dbo_current_ubatch_id() - do_recv_hook = dbo_enabled() handle = self.handles[a2a_idx] assert handle is not None @@ -225,17 +225,47 @@ def finalize( if apply_router_weight_on_input: # weights have already been applied. 
combine_topk_weights = torch.ones_like(topk_weights) - - # TODO (varun) : Enable zero copy mode - dbo_maybe_run_recv_hook() - _, _, recv_hook = self.buffer.low_latency_combine(fused_expert_output, - topk_ids, - combine_topk_weights, - handle, - async_finish=False, - zero_copy=False, - return_recv_hook=do_recv_hook, - out=output) - if recv_hook is not None: - dbo_register_recv_hook(recv_hook) - dbo_yield() \ No newline at end of file + + ######################################################################## + yield # Pre-combine done + ######################################################################## + + _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + handle, + async_finish=False, + zero_copy=False, + return_recv_hook=True, + out=output) + + ######################################################################## + yield # Combine send done + ######################################################################## + + recv_hook() + + ######################################################################## + yield # Combine recv done + ######################################################################## + + return None + + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.AsyncFinalizeOps: + return mk.AsyncFinalizeOps.from_generator( + self._create_finalize_ops(output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl)) + diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a82f86f08b2e..f8904dbfaf44 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Callable, Optional, Union, final +from typing import Callable, Optional, Union, final, Generator import torch @@ -13,10 +13,9 @@ from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv -from vllm.v1.worker.ubatching import (dbo_enabled, - dbo_yield, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook) +from vllm.v1.worker.ubatching import (Schedule, dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield, + dbo_current_schedule) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -165,9 +164,122 @@ def apply(self, output: Optional[torch.Tensor], Optional[torch.Tensor], ] -ReceiverType = Callable[[any], PrepareResultType] +ReceiverType = Callable[[], PrepareResultType] +# +# Prepare and Finalize Op Chains +# +# The prepare and finalize functions are broken down into a chain of sequential +# operations/steps. This + +class _PhasedGen[R]: + """ + Enforce an exact number of yields (phases), then a final return. + + Contract: + - The generator must yield exactly `expected_yields` times. + - The next advance must StopIteration with a return value (may be None). + - Early StopIteration or extra yields raise RuntimeError. + - Duplicate step/finish after completion raises RuntimeError. 
+ """ + __slots__ = ("_gen", "_expected", "_steps", "_done", "_ret") + + def __init__(self, gen: Generator[None, None, R], expected_yields: int): + self._gen = gen + self._expected = expected_yields + self._steps = 0 + self._done = False + self._ret: Optional[R] = None + + def step(self, label: str) -> None: + if self._done: + raise RuntimeError(f"Generator already finished; unexpected '{label}'.") + if self._steps >= self._expected: + raise RuntimeError( + f"Too many steps: called '{label}' after {self._expected} phases; " + "expected to finish instead." + ) + try: + next(self._gen) + except StopIteration: + raise RuntimeError( + f"Generator ended early during '{label}' " + f"(completed {self._steps}/{self._expected} phases)." + ) + self._steps += 1 + + def finish(self, label: str) -> R: + if self._done: + raise RuntimeError(f"Generator already finished; duplicate '{label}'.") + if self._steps != self._expected: + raise RuntimeError( + f"Cannot finish at '{label}': only {self._steps}/" + f"{self._expected} phases completed." + ) + try: + next(self._gen) + except StopIteration as e: + self._done = True + self._ret = e.value # may be None + return self._ret # type: ignore[return-value] + else: + raise RuntimeError( + f"Generator yielded more than expected ({self._expected}); " + f"should have finished at '{label}'." + ) + +@dataclass +class AsyncOps[R]: + """ + 3-phase async: + 1) prepare() + 2) send() + 3) recv() + 4) finish() -> R + """ + prepare: Callable[[], None] + send: Callable[[], None] + recv: Callable[[], None] + finish: Callable[[], R] + + @classmethod + def from_generator(cls, gen: Generator[None, None, R]) -> 'AsyncOps[R]': + ph = _PhasedGen[R](gen, expected_yields=3) + return cls( + prepare=lambda: ph.step("prepare"), + send=lambda: ph.step("send"), + recv=lambda: ph.step("recv"), + finish=lambda: ph.finish("finish"), + ) + + +@dataclass +class SyncOps[R]: + """ + 2-phase sync: + 1) prepare() + 2) send_recv() + 3) finish() -> R + """ + prepare: Callable[[], None] + send_recv: Callable[[], None] + finish: Callable[[], R] + + @classmethod + def from_generator(cls, gen: Generator[None, None, R]) -> 'SyncOps[R]': + ph = _PhasedGen[R](gen, expected_yields=2) + return cls( + prepare=lambda: ph.step("prepare"), + send_recv=lambda: ph.step("send_recv"), + finish=lambda: ph.finish("finish"), + ) + +AsyncPrepareOps = AsyncOps[PrepareResultType] +SyncPrepareOps = SyncOps[PrepareResultType] +AsyncFinalizeOps = AsyncOps[None] +SyncFinalizeOps = SyncOps[None] + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -176,7 +288,7 @@ class FusedMoEPrepareAndFinalize(ABC): """ @abstractmethod - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -187,7 +299,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> PrepareResultType: + ) -> Union[SyncPrepareOps, AsyncPrepareOps]: """ Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. 
@@ -259,7 +371,7 @@ def prepare_async( raise NotImplementedError @abstractmethod - def finalize( + def create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -267,7 +379,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: TopKWeightAndReduce, - ) -> None: + ) -> Union[SyncFinalizeOps, AsyncFinalizeOps]: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -501,12 +613,14 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, class SharedResizableBuffer: + def __init__(self): self.buffer = None - + # NOTE: Assumes the first call to get() is the largest shape, # this is usually true due to the profile run. - def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): + def get(self, shape: tuple[int, ...], device: torch.device, + dtype: torch.dtype): shape_numel = prod(shape) if self.buffer is None or self.buffer.numel() < shape_numel: self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) @@ -583,14 +697,12 @@ def _do_fused_experts( # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = self.workspace13_buffer.get( - workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = self.workspace2_buffer.get( - workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + workspace13 = self.workspace13_buffer.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = self.workspace2_buffer.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -682,10 +794,9 @@ def _maybe_chunk_fused_experts( (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = self.fused_out_buffer.get( - fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + fused_out = self.fused_out_buffer.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -827,55 +938,52 @@ def forward( global_num_experts = local_num_experts shared_output: torch.Tensor + + prepare_ops = self.prepare_finalize.create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) - if not self.prepare_finalize.supports_async(): - # We shouldn't be running an a2a kernel that doesn't - # support async prepare/finalize - assert not dbo_enabled() + prepare_ops.prepare() + if isinstance(prepare_ops, SyncOps): # Run shared experts serially with dispatch. if self.shared_experts is not None: shared_output = self.shared_experts(a1) - - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + prepare_ops.send_recv() else: + assert isinstance(prepare_ops, AsyncOps) + # Overlap shared expert compute with all2all dispatch. 
dbo_maybe_run_recv_hook() - hook, receiver = self.prepare_finalize.prepare_async( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) - - # assert self.shared_experts is not None - if self.shared_experts is not None: + prepare_ops.send() + + recv_done = dbo_register_recv_hook( + lambda: prepare_ops.recv(), + schedules=(Schedule.MLP_OVERLAP, Schedule.MLP_SHARED_OVERLAP)) + dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) + + # If we are using the MLP_SHARED_OVERLAP schedule, we overlap with + # the combine instead of the dispatch. + # TODO(lucas): refactor this scheduling logic + if self.shared_experts is not None \ + and dbo_current_schedule() != Schedule.MLP_SHARED_OVERLAP: shared_output = self.shared_experts(a1) - dbo_register_recv_hook(hook) - dbo_yield() + dbo_yield(schedules=(Schedule.ATTN_SHARED_OVERLAP, Schedule.MLP_SHARED_OVERLAP)) - if dbo_enabled(): - hook = None + if not recv_done: + prepare_ops.recv() - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = receiver(hook) + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = prepare_ops.finish() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -914,7 +1022,7 @@ def forward( apply_router_weight_on_input=apply_router_weight_on_input, ) - self.prepare_finalize.finalize( + finalize_ops = self.prepare_finalize.create_finalize_ops( output, fused_out, topk_weights, @@ -922,6 +1030,28 @@ def forward( apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), ) + + if isinstance(finalize_ops, SyncOps): + finalize_ops.prepare() + finalize_ops.send_recv() + finalize_ops.finish() + else: + assert isinstance(finalize_ops, AsyncOps) + finalize_ops.prepare() + dbo_maybe_run_recv_hook() + finalize_ops.send() + + # If we didn't overlap with the dispatch overlap with the combine + # TODO(lucas): refactor this scheduling logic + if self.shared_experts is not None and shared_output is None: + shared_output = self.shared_experts(a1) + + if dbo_register_recv_hook( + lambda: finalize_ops.recv(), all_schedules=True): + dbo_yield(all_schedules=True) + else: + finalize_ops.recv() + finalize_ops.finish() if self.shared_experts is None: return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 2ae79e69f555..f9c11d953dff 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -92,7 +93,7 @@ def num_dispatchers(self) -> int: def supports_async(self) -> bool: return True - def prepare_async( + def _create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -202,7 +203,11 @@ def prepare_async( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - self.a2a.dispatch( + ######################################################################## + yield # Pre-dispatch done + ######################################################################## + + 
self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, out_expert_x_scale=expert_x_scale, @@ -214,30 +219,11 @@ def prepare_async( do_recv=False, ) - return lambda: self._receiver( - expert_num_tokens, - expert_x, - expert_x_scale, - a1q, - a1q_scale, - topk_ids, - bound_m, - orig_a_scale_block_shape, - ) - - def _receiver( - self, - expert_num_tokens: torch.Tensor, - expert_x: torch.Tensor, - expert_x_scale: Optional[torch.Tensor], - a1q: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - bound_m: Optional[torch.Tensor], - orig_a_scale_block_shape: Optional[int], - ) -> mk.PrepareResultType: + ######################################################################## + yield # Dispatch send done + ######################################################################## - self.a2a.dispatch( + self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, out_expert_x_scale=expert_x_scale, @@ -249,16 +235,21 @@ def _receiver( do_recv=True, ) + ######################################################################## + yield # Dispatch recv done + ######################################################################## + if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) return expert_x, expert_x_scale, expert_tokens_meta, None, None - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -269,21 +260,21 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - receiver = self.prepare_async( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - num_experts, - expert_map, - apply_router_weight_on_input, - quant_config, - ) - return receiver() - - def finalize( + ) -> mk.AsyncPrepareOps: + return mk.AsyncPrepareOps.from_generator( + self._create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + )) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -314,8 +305,55 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - self.a2a.combine(out_tokens=output, - indices=topk_ids.view(dtype=torch.uint32), - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m) + ######################################################################## + yield # Pre-combine done + ######################################################################## + + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids.view(dtype=torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + ######################################################################## + yield # Combine send done (no-op for pplx combine) + ######################################################################## + + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids.view(dtype=torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + + 
######################################################################## + yield # Combine recv done + ######################################################################## + + return None + + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.AsyncFinalizeOps: + return mk.AsyncFinalizeOps.from_generator( + self._create_finalize_ops( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + )) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bb95a1dbf122..1e8918a8100b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -721,8 +721,10 @@ def forward( return hidden_states + class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): + delayed_dbo_start = True packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b4c9aae254ea..0822a05c5fee 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -218,6 +218,7 @@ infer_global_hyperparameters, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.ubatching import Schedule, dbo_yield try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -1274,11 +1275,6 @@ def forward( if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) - if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) - if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( @@ -1313,6 +1309,14 @@ def forward( layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) + dbo_yield(schedules=(Schedule.ATTN_SHARED_OVERLAP,)) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata, layer._k_scale) + + if has_decode: output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 12e79ff165f4..e53f3ce099ff 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -3,6 +3,7 @@ import multiprocessing import os import pickle +import queue import signal import threading import time @@ -18,6 +19,7 @@ from typing import Any, Callable, Optional, Union, cast import cloudpickle +import torch import vllm.envs as envs from vllm.config import VllmConfig @@ -33,7 +35,8 @@ get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, + ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -412,6 +415,14 @@ def __init__( # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) + self.async_output_queue: queue.Queue = queue.Queue() + 
self.async_output_copy_stream = torch.cuda.Stream() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy") + self.async_output_copy_thread.start() + # Initialize device and loads weights self.worker.init_device() self.worker.load_model() @@ -593,6 +604,18 @@ class ResponseStatus(Enum): SUCCESS = auto() FAILURE = auto() + def enqueue_worker_output(self, output: Any) -> None: + if isinstance(output, AsyncModelRunnerOutput): + output = output.serialize(self.async_output_copy_stream) + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + output = self.async_output_queue.get() + self.enqueue_worker_output(output) + def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: @@ -617,5 +640,4 @@ def worker_busy_loop(self): continue if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + self.async_output_queue.put(output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f8d6b24702f3..22da4b34e554 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -114,6 +114,34 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +# ModelRunnerOutput wrapper for async scheduling. +# Contains GPU tensors which must be serialized before sending +# to the scheduler process. +@dataclass +class AsyncModelRunnerOutput: + model_runner_output: ModelRunnerOutput + + # [num_reqs, max_num_generated_tokens] + sampled_token_ids: torch.Tensor + + invalid_req_indices: list[int] + + def serialize(self, copy_stream: torch.cuda.Stream) -> ModelRunnerOutput: + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(copy_stream): + copy_stream.wait_stream(default_stream) + sampled_token_ids_cpu = self.sampled_token_ids.to( + 'cpu', non_blocking=True) + copy_stream.synchronize() + valid_sampled_token_ids = sampled_token_ids_cpu.tolist() + for i in self.invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self.model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + @dataclass class DraftTokenIds: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ad70d9efaaaa..83fc821b8494 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -250,6 +250,11 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 54a51392b813..ebb6fca4c650 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -69,8 +69,8 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - LogprobsTensors, ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from 
vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -238,6 +238,8 @@ def __init__( is_pooling_model=self.is_pooling_model, ) + self.use_async_scheduling = self.scheduler_config.async_scheduling + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -709,6 +711,73 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is not None: + # Async scheduling case, we need to copy the sampled token ids + # from the previous iteration. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + current_req_id_to_index = self.input_batch.req_id_to_index + assert prev_req_id_to_index is not None + common_req_ids = set(prev_req_id_to_index.keys()).intersection( + set(current_req_id_to_index.keys())) + if common_req_ids: + current_common_req_indices = [ + current_req_id_to_index[req_id] + for req_id in common_req_ids + ] + prev_common_req_indices = [ + prev_req_id_to_index[req_id] for req_id in common_req_ids + ] + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_indices = [ + int(cu_num_tokens[idx]) - 1 + for idx in current_common_req_indices + ] + if len(flattened_indices) < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if flattened_indices == prev_common_req_indices and \ + set(flattened_indices) == \ + set(range(len(flattened_indices))): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 + self.input_ids.gpu[:len(flattened_indices)].copy_( + self.input_batch.prev_sampled_token_ids[:len( + flattened_indices)].squeeze(1), + non_blocking=True) + else: + # Upload the index tensors asynchronously + # so the scatter can be non-blocking + input_ids_index_tensor = torch.tensor( + flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, + non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor].squeeze(1)) + else: + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + else: + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput" ) -> tuple[PerLayerAttnMetadata, torch.Tensor, @@ -800,7 +869,8 @@ def _prepare_inputs( max_seq_len = self.seq_lens.np[:num_reqs].max().item() # Copy the tensors to the GPU. 
- self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( @@ -916,6 +986,10 @@ def _prepare_inputs( builder, ) +<<<<<<< HEAD + +======= +>>>>>>> nm/sage/dbo-full-cudagraphs if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( ubatch_slices, common_attn_metadata) @@ -1563,6 +1637,7 @@ def get_dp_padding_ubatch( should_ubatch = False # Note that we compute the number of padded tokens per ubatch + (should_ubatch, num_tokens_across_dp) = self.should_ubatch_with_num_tokens( should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch) @@ -1649,7 +1724,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -1852,6 +1927,12 @@ def execute_model( # so that we could clear the sampled tokens before returning. discard_sampled_tokens_req_indices.append(i) + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors @@ -1864,21 +1945,41 @@ def execute_model( scheduler_output.num_scheduled_tokens, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = self._to_list(sampled_token_ids) + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. 
+ self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. @@ -1886,7 +1987,12 @@ def execute_model( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. req_ids = self.input_batch.req_ids - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1901,6 +2007,7 @@ def execute_model( start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) @@ -1922,9 +2029,9 @@ def execute_model( self.eplb_step() - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -1933,6 +2040,15 @@ def execute_model( num_nans_in_logits=num_nans_in_logits, ) + if self.use_async_scheduling: + return AsyncModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + ) + + return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None @@ -2193,12 +2309,16 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.vllm_config, runtime_mode=CUDAGraphMode.FULL) elif self.parallel_config.enable_microbatching: + delayed_start = getattr(self.model, "delayed_dbo_start", False) + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.FULL, self.device) + CUDAGraphMode.FULL, self.device, + delayed_start=delayed_start) else: self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.NONE, self.device) + CUDAGraphMode.NONE, self.device, + delayed_start=delayed_start) def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index affba877ecf9..99c805a3e949 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,7 +5,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed @@ -28,8 +28,8 @@ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, 
ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -355,7 +355,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 if forward_pass and not get_pp_group().is_first_rank: @@ -365,7 +365,7 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - if isinstance(output, ModelRunnerOutput): + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output assert isinstance(output, IntermediateTensors) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index fe6ae73066b3..8835c3d4dfe1 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from contextlib import contextmanager +from enum import Enum from typing import Optional - +from torch.library import Library +from functools import lru_cache import torch from vllm import forward_context @@ -12,6 +15,38 @@ _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] + +class Schedule(Enum): + # Schedule notation legend: + # S = Shared expert + # A0 = MLA qkv pro, + # A1 = Core attn + out proj + MoE gate + # D = Dispatch + # C = Combine + + # Comp: |-A0₀-A1₀-S₀-||-MLP₁-||-MLP₀-||-A0₁-A1₁-S₁-| + # Comm: |-----D₁-----||--D₀--||--C₁--||-----C₀-----| + # Order: D₁ send, A0₀, A1₀, S₀, D₁ recv, D₀ send, MLP₁, D₀ recv, + # C₁ send, MLP₀, C₁ recv, C₀ send, A0₁, A1₁, S₁, C₀ recv. + MLP_OVERLAP = "mlp_overlap" + + # Comp: |-A0₀-A1₀-||-MLP₁-||-S₁-MLP₀-||-S₀-A0₁-A1₁-| + # Comm: |----D₁---||--D₀--||----C₁---||-----C₀-----| + # Order: D₁ send, A0₀, A1₀, D₁ recv, D₀ send, MLP₁, D₀ recv, + # C₁ send, S₁, MLP₀, C₁ recv, C₀ send, S₀, A0₁, A1₁, C₀ recv. + MLP_SHARED_OVERLAP = "mlp_shared_overlap" + + # Comp: |-S₀-A0₁-|-MLP₀-|-A1₁-||-S₁-A0₀-|-MLP₁-|-A1₀-| + # Comm: |---D₀---| |-C₀--||---D₁---| |-C₁--| + # Order: D₀ send, S₀, A0₁, D₀ recv, MLP₀, C₀ send, A1₁, C₀ recv, + # D₁ send, S₁, A0₀, D₁ recv, MLP₁, C₁ send, A1₀, C₁ recv. + ATTN_SHARED_OVERLAP = "attn_shared_overlap" + +_SCHEDULE_WAIT_STAGES = { # Default is 1 + Schedule.ATTN_SHARED_OVERLAP: 2, +} + + class UBatchContext: """ Context manager for micro-batching synchronization using threading events. 
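One thing to reconcile with this enum: `ParallelConfig.microbatch_schedule` is typed as `Literal["mlp_overlap", "ATTN_SHARED_OVERLAP"]`, while the enum value above is lower-case `"attn_shared_overlap"`, and `UBatchWrapper._make_ubatch_metadata` compares the config string against `Schedule.ATTN_SHARED_OVERLAP.value`; as written, an upper-case config value would silently fall back to `MLP_OVERLAP`. A small sketch of a case-insensitive mapping that would tolerate either spelling (the helper `parse_schedule` is illustrative, not existing vLLM code; the enum is restated so the snippet is self-contained):

```python
from enum import Enum


class Schedule(Enum):
    MLP_OVERLAP = "mlp_overlap"
    MLP_SHARED_OVERLAP = "mlp_shared_overlap"
    ATTN_SHARED_OVERLAP = "attn_shared_overlap"


def parse_schedule(name: str,
                   default: Schedule = Schedule.MLP_OVERLAP) -> Schedule:
    """Map a --microbatch-schedule string to a Schedule, case-insensitively."""
    try:
        return Schedule(name.lower())
    except ValueError:
        return default


assert parse_schedule("attn_shared_overlap") is Schedule.ATTN_SHARED_OVERLAP
assert parse_schedule("ATTN_SHARED_OVERLAP") is Schedule.ATTN_SHARED_OVERLAP
assert parse_schedule("bogus") is Schedule.MLP_OVERLAP
```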
@@ -27,7 +62,8 @@ def __init__(self, cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event, - schedule: str = "default"): + started: bool = True, + schedule: Schedule = Schedule.MLP_OVERLAP): self.id = id self.comm_stream = comm_stream self.compute_stream = compute_stream @@ -39,6 +75,7 @@ def __init__(self, self.gpu_comm_done_event = gpu_comm_done_event self.gpu_compute_done_event = gpu_compute_done_event self.schedule = schedule + self.started = started self.recv_hook = None def __enter__(self): @@ -49,6 +86,12 @@ def __enter__(self): self.cpu_wait_event.wait() self.cpu_wait_event.clear() + + if self.id > 0: + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) + for _ in range(wait_stages - 1): + self._cpu_yield(check_context=False) + self._restore_context() # Assume we start on the compute stream assert current_stream() == self.compute_stream @@ -56,10 +99,22 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT - _CURRENT_CONTEXTS[self.id] = None - del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + + if self.id == 0: + # Keep advancing the next micro-batch + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) + for _ in range(wait_stages - 1): + self.yield_() + # Cleanup and trailing recv hooks + self.maybe_run_recv_hook() + + self.maybe_run_recv_hook() self.cpu_signal_event.set() self.cpu_wait_event.clear() + + del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + _CURRENT_CONTEXTS[self.id] = None + self.current_stream = self.compute_stream torch.cuda.set_stream(self.current_stream) return False @@ -92,12 +147,17 @@ def stream_string(self): assert self.current_stream == self.comm_stream return "COMM" - def _cpu_yield(self): + def _cpu_yield(self, check_context: bool = True): # It is critical for correctness that only one thread is running # at a time. 
These asserts just make sure that this is the only # thread running before waking the other one up and going to sleep - assert forward_context._forward_context == self.forward_context - assert current_stream() == self.current_stream + print(f"CPU yield {self.id} {type(forward_context._forward_context)} {type(self.forward_context)}") + assert ( + not check_context or + forward_context._forward_context is self.forward_context) + assert ( + not check_context or + current_stream() == self.current_stream) assert not self.cpu_wait_event.is_set() self.cpu_signal_event.set() @@ -109,7 +169,12 @@ def switch_to_comm_sync(self): self._signal_compute_done() self.update_stream(self.comm_stream) self._wait_comm_done() - + + def switch_to_compute_sync(self): + self._signal_comm_done() + self.update_stream(self.compute_stream) + self._wait_compute_done() + def maybe_run_recv_hook(self): if self.recv_hook is not None: self.recv_hook() @@ -118,9 +183,9 @@ def maybe_run_recv_hook(self): def yield_(self): self.current_stream = current_stream() self._cpu_yield() - if self.current_stream == current_stream(): + if self.current_stream != current_stream(): self.update_stream(self.current_stream) - + def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream self._signal_compute_done() @@ -141,30 +206,82 @@ def yield_and_switch_from_comm_to_compute(self): def dbo_enabled() -> bool: return len(_THREAD_ID_TO_CONTEXT) > 0 + def dbo_current_ubatch_id() -> int: if len(_THREAD_ID_TO_CONTEXT) == 0: return 0 return _THREAD_ID_TO_CONTEXT[threading.get_ident()] -def _register_ubatch_function(func, context_offset): - def wrapper(*args, **kwargs): + +def dbo_start(): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + ctx.started = True + + +def dbo_current_schedule() -> Optional[Schedule]: + if len(_THREAD_ID_TO_CONTEXT) == 0: + return None + return _CURRENT_CONTEXTS[dbo_current_ubatch_id()].schedule + +def _register_ubatch_function(func, all_schedules_default: bool = False): + + def wrapper(schedules: tuple[Schedule] = (), + all_schedules: bool = all_schedules_default): if len(_THREAD_ID_TO_CONTEXT) > 0: - ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + context_offset + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] ctx = _CURRENT_CONTEXTS[ctx_idx] - func(ctx, *args, **kwargs) + if not ctx.started: + return + if all_schedules or ctx.schedule in schedules: + func(ctx) + return wrapper -dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(UBatchContext.yield_and_switch_from_compute_to_comm, 0) -dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(UBatchContext.yield_and_switch_from_comm_to_compute, 0) -dbo_yield = _register_ubatch_function(UBatchContext.yield_, 0) -dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook, 0) -dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync, 0) -def dbo_register_recv_hook(recv_hook): +dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( + UBatchContext.yield_and_switch_from_compute_to_comm) +dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( + UBatchContext.yield_and_switch_from_comm_to_compute) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) +# For `dbo_maybe_run_recv_hook` we usually want to run the recv hook for +# since its already conditional on a recv_hook being registered, so we default +# all_schedules to 
True to make the code a bit simpler. +dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook, all_schedules_default=True) +dbo_switch_to_comm_sync = _register_ubatch_function( + UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute_sync = _register_ubatch_function( + UBatchContext.switch_to_compute_sync) + + +# DBO start needs to be callable from inside the torch compile region so +# we register it as a custom op. +lib = Library("vllm_dbo", "DEF") +lib.define("start(Tensor! x) -> ()") # in-place, returns x + +@torch.library.impl("vllm_dbo::start", "CompositeImplicitAutograd") +def _dbo_start_impl(x: torch.Tensor): + dbo_start() + return None + + +@lru_cache(maxsize=1) +def dbo_debug_annotate(): + return True + +def dbo_register_recv_hook(recv_hook, + schedules: tuple[Schedule] = (), + all_schedules: bool = False) -> bool: if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] - next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] - next_ctx.recv_hook = recv_hook + if all_schedules or _CURRENT_CONTEXTS[ctx_idx].schedule in schedules: + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx.recv_hook = recv_hook + return True + return False + def make_ubatch_contexts( num_micro_batches: int, @@ -172,7 +289,9 @@ def make_ubatch_contexts( comm_stream: torch.cuda.Stream, forward_contexts: list[ForwardContext], ready_barrier: threading.Barrier, - schedule: str = "default", + device: Optional[torch.device] = None, + schedule: Schedule = Schedule.MLP_OVERLAP, + delayed_start: bool = False, ) -> list[UBatchContext]: assert num_micro_batches == 2, "only been tested with 2 micro-batches" """ @@ -200,6 +319,7 @@ def make_ubatch_contexts( num_micro_batches], gpu_comm_done_event=gpu_comm_done_events[i], gpu_compute_done_event=gpu_compute_done_events[i], + started=not delayed_start, schedule=schedule) ctxs.append(ctx)
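For intuition about the handshake that `UBatchContext.__enter__`, `_cpu_yield`, and `make_ubatch_contexts` are built on: each micro-batch thread blocks on its own wait event, and handing over the CPU means setting the peer's event and then sleeping until the peer yields back, so only one thread runs Python at a time. Below is a minimal standalone sketch of that ping-pong for two micro-batches; `NUM_STEPS` and the event wiring are illustrative, not the vLLM implementation.

```python
import threading

NUM_STEPS = 3
# wait_events[i] is what micro-batch i blocks on; the peer sets it to hand over.
wait_events = [threading.Event(), threading.Event()]


def ubatch_thread(ubatch_id: int) -> None:
    my_event = wait_events[ubatch_id]
    peer_event = wait_events[1 - ubatch_id]
    my_event.wait()          # block until it is our turn (cf. __enter__)
    my_event.clear()
    for step in range(NUM_STEPS):
        print(f"ubatch {ubatch_id}: running step {step}")
        peer_event.set()     # cf. _cpu_yield(): wake the peer ...
        my_event.wait()      # ... and sleep until it yields back
        my_event.clear()
    peer_event.set()         # cf. __exit__: let the peer finish its last wait


threads = [threading.Thread(target=ubatch_thread, args=(i,)) for i in (0, 1)]
for t in threads:
    t.start()
wait_events[0].set()         # kick off micro-batch 0 first
for t in threads:
    t.join()
```

The `_SCHEDULE_WAIT_STAGES` handling layers on top of exactly this mechanism: for `ATTN_SHARED_OVERLAP` (two wait stages), micro-batch 1 performs one extra yield in `__enter__` before doing any work, and micro-batch 0 keeps advancing the peer by the same amount in `__exit__`, so the two threads execute the schedule offset by an extra stage.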