From 3a1d48c0e8c5467555e4d11c22d31f851ac7687d Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 29 Aug 2025 17:36:59 +0000
Subject: [PATCH 1/8] shared expert overlap

Signed-off-by: Sage Moore
---
 docs/design/fused_moe_modular_kernel.md | 8 +-
 examples/offline_inference/data_parallel.py | 11 +-
 tests/kernels/moe/test_pplx_moe.py | 86 ++++++--
 tests/kernels/moe/utils.py | 149 ++++++++++++++
 .../base_device_communicator.py | 7 +-
 .../fused_moe/deepep_ht_prepare_finalize.py | 130 +++++++++----
 .../fused_moe/deepep_ll_prepare_finalize.py | 53 ++++-
 .../flashinfer_cutlass_prepare_finalize.py | 4 +-
 .../layers/fused_moe/fused_batched_moe.py | 4 +-
 vllm/model_executor/layers/fused_moe/layer.py | 184 ++++++++++++++----
 .../layers/fused_moe/modular_kernel.py | 155 ++++++++++++---
 .../layers/fused_moe/pplx_prepare_finalize.py | 77 +++++++-
 .../layers/fused_moe/prepare_finalize.py | 4 +-
 .../layers/quantization/awq_marlin.py | 4 +-
 .../layers/quantization/bitsandbytes.py | 2 +-
 .../compressed_tensors_moe.py | 12 +-
 .../layers/quantization/experts_int8.py | 4 +-
 .../model_executor/layers/quantization/fp8.py | 4 +-
 .../layers/quantization/gguf.py | 4 +-
 .../layers/quantization/gptq_marlin.py | 2 +-
 .../layers/quantization/modelopt.py | 4 +-
 .../layers/quantization/moe_wna16.py | 4 +-
 .../layers/quantization/mxfp4.py | 4 +-
 .../layers/quantization/quark/quark_moe.py | 6 +-
 .../model_executor/layers/quantization/rtn.py | 4 +-
 .../layers/shared_fused_moe/__init__.py | 6 +
 .../shared_fused_moe/shared_fused_moe.py | 48 +++++
 vllm/model_executor/models/deepseek_v2.py | 97 +++++----
 vllm/model_executor/models/glm4_moe.py | 2 +
 vllm/model_executor/models/llama4.py | 25 +--
 30 files changed, 885 insertions(+), 219 deletions(-)
 create mode 100644 vllm/model_executor/layers/shared_fused_moe/__init__.py
 create mode 100644 vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py

diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md
index 202e9c1caf11..d4edde068f49 100644
--- a/docs/design/fused_moe_modular_kernel.md
+++ b/docs/design/fused_moe_modular_kernel.md
@@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
 
 ### FusedMoEPrepareAndFinalize
 
-The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions.
-The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
+The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
+The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, `prepare_no_receive` is like `prepare` except that it does not wait to receive results from other workers; instead it returns a "receiver" thunk that must be called to wait for the final results of the dispatch. Not all `FusedMoEPrepareAndFinalize` classes are required to support this method, but when it is available it can be used to overlap other work with the initial all-to-all communication, e.g. interleaving the shared experts with the fused experts. The `finalize` function is responsible for invoking the All2All Combine.
Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") @@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. +`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. + +`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. + `FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. `FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 8453f35068ec..2c7b81e42e41 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -90,7 +90,13 @@ def parse_args(): parser.add_argument( "--enable-microbatching", action="store_true", - help=("Enable microbatched execution"), + help=("Enable microbatched execution") + ) + parser.add_argument( + "--compilation-config", + type=int, + default=0, + help=("Compilation optimization (O) level 0-3."), ) parser.add_argument( "--quantization", @@ -111,6 +117,7 @@ def main( trust_remote_code, max_num_seqs, max_model_len, + compilation_config, gpu_memory_utilization, enable_microbatching, quantization, @@ -169,6 +176,7 @@ def start(rank): gpu_memory_utilization=gpu_memory_utilization, enable_microbatching=enable_microbatching, quantization=quantization, + compilation_config=compilation_config, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -225,6 +233,7 @@ def start(rank): args.trust_remote_code, args.max_num_seqs, args.max_model_len, + args.compilation_config, args.gpu_memory_utilization, args.enable_microbatching, args.quantization, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 3f36d7ada2e9..394f52114085 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,10 +4,11 @@ Run `pytest tests/kernels/test_pplx_moe.py`. 
""" +import copy import itertools import textwrap import traceback -from typing import Callable, Optional +from typing import Callable, Optional, Union import pytest import torch @@ -21,7 +22,10 @@ except ImportError: has_pplx = False -from tests.kernels.moe.utils import make_test_weights, naive_batched_moe +from tests.kernels.moe.modular_kernel_tools.parallel_utils import ( + _set_vllm_config) +from tests.kernels.moe.utils import (make_shared_experts, make_test_weights, + naive_batched_moe) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config @@ -511,7 +515,8 @@ def pplx_moe( block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, -) -> torch.Tensor: + shared_experts: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] @@ -546,6 +551,7 @@ def pplx_moe( fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + shared_experts, ) # Note: workers with the same dp_rank must use the exact same inputs. @@ -586,7 +592,11 @@ def pplx_moe( global_num_experts=num_experts) if use_cudagraphs: - out.fill_(0) + if isinstance(out, tuple): + out[0].fill_(0) + out[1].fill_(0) + else: + out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): @@ -626,6 +636,7 @@ def _pplx_moe( per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, use_internode: bool = False, + shared_experts: Optional[torch.nn.Module] = None, ): try: if use_internode: @@ -666,6 +677,11 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + if shared_experts is not None: + shared_output = shared_experts(a) + else: + shared_output = None + torch_output = torch_experts( a, w1, @@ -696,7 +712,7 @@ def _pplx_moe( block_shape=block_shape, ) - pplx_output = pplx_moe( + pplx_outputs = pplx_moe( group_name, rank, world_size, @@ -713,8 +729,24 @@ def _pplx_moe( quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, + shared_experts=shared_experts, ) + if shared_experts is None: + pplx_shared_output = None + pplx_output = pplx_outputs + assert isinstance(pplx_output, torch.Tensor) + else: + pplx_shared_output, pplx_output = pplx_outputs + + if shared_output is not None: + assert pplx_shared_output is not None + chunked_shared_output = chunk_by_rank( + shared_output, pgi.rank, + pgi.world_size).to(pplx_shared_output.device) + else: + chunked_shared_output = None + chunked_batch_output = chunk_by_rank( batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -727,6 +759,15 @@ def _pplx_moe( chunked_batch_output, atol=3e-2, rtol=3e-2) + + if shared_experts is not None: + assert chunked_shared_output is not None + assert pplx_shared_output is not None + torch.testing.assert_close(pplx_shared_output, + chunked_shared_output, + atol=3e-2, + rtol=3e-2) + finally: if use_internode: nvshmem_finalize() @@ -788,7 +829,8 @@ def test_pplx_moe_slow( def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - make_weights: bool, test_fn: Callable): + use_shared_experts: bool, make_weights: bool, + test_fn: Callable): def format_result(msg, ex=None): if ex is not None: @@ -803,6 +845,14 @@ def format_result(msg, ex=None): else: print(f"PASSED {msg}") + if use_shared_experts: 
+ # Note: this config is only needed for the non-naive shared experts. + new_vllm_config = copy.deepcopy(vllm_config) + new_vllm_config.parallel_config.data_parallel_size = pgi.world_size + new_vllm_config.parallel_config.enable_expert_parallel = True + _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, + pgi.local_rank) + current_platform.seed_everything(7) combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]) @@ -819,9 +869,11 @@ def format_result(msg, ex=None): use_fp8_w8a8 = False quant_dtype = None - test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " - f"dtype={dtype}, per_act_token={per_act_token_quant}, " - f"block_shape={block_shape}") + test_desc = ( + f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}, use_internode={use_internode}, " + f"use_shared_experts={use_shared_experts}") if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): @@ -852,6 +904,14 @@ def format_result(msg, ex=None): args["w1_s"] = w1_s args["w2_s"] = w2_s + if use_shared_experts: + args["shared_experts"] = make_shared_experts( + n, + k, + in_dtype=a.dtype, + quant_dtype=quant_dtype, + ) + try: test_fn( pgi=pgi, @@ -891,18 +951,20 @@ def test_pplx_prepare_finalize( current_platform.seed_everything(7) world_size, dp_size = world_dp_size parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, _pplx_prepare_finalize) + use_internode, False, False, _pplx_prepare_finalize) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.parametrize("use_shared_experts", [False, True]) @requires_pplx @multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, + use_shared_experts: bool, ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, - _pplx_moe) + parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, + use_shared_experts, True, _pplx_moe) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 82960bd57345..4b58a28eed12 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -8,6 +8,7 @@ from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX) +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) @@ -282,3 +283,151 @@ def per_token_cast_to_fp8( x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +# CustomOp? 
+class BaselineMM(torch.nn.Module): + + def __init__( + self, + b: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.b = b.to(dtype=torch.float32) + self.out_dtype = out_dtype + + def forward( + self, + a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return torch.mm(a.to(dtype=torch.float32), + self.b).to(self.out_dtype), None + + +class TestMLP(torch.nn.Module): + + def __init__( + self, + w1: torch.Tensor, + w2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.gate_up_proj = BaselineMM(w1, out_dtype) + self.down_proj = BaselineMM(w2, out_dtype) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def make_naive_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, +) -> torch.nn.Module: + w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15 + w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15 + return TestMLP(w1, w2, out_dtype=in_dtype) + + +class RealMLP(torch.nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + w1: torch.Tensor, + w2: torch.Tensor, + hidden_act: str = "silu", + quant_config=None, + reduce_results: bool = True, + prefix: str = "", + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + ) -> None: + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, RowParallelLinear) + + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.gate_up_proj.register_parameter( + "weight", torch.nn.Parameter(w1, requires_grad=False)) + self.gate_up_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)) + self.gate_up_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + self.down_proj.register_parameter( + "weight", torch.nn.Parameter(w2, requires_grad=False)) + self.down_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)) + self.down_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def make_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Union[torch.dtype, str, None] = None, +) -> torch.nn.Module: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + 1, + N, + K, + in_dtype=in_dtype, + quant_dtype=quant_dtype, + ) + old_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(in_dtype) + if quant_dtype == torch.float8_e4m3fn: + w1 = w1[0].transpose(0, 1) + w2 = w2[0].transpose(0, 1) + w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None + w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None + quant_config = Fp8Config(True) + else: + w1 = w1[0] + w2 = w2[0] + w1_s = None + w2_s = None + quant_config = None + + return RealMLP(K, + N, + w1, + w2, + "silu", + quant_config, + w1_s=w1_s, + w2_s=w2_s) + finally: + torch.set_default_dtype(old_dtype) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 5b2c571afc98..2a63ee407b1d 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -255,10 +255,13 @@ def prepare_communication_buffer_for_model(self, moe_modules = [ module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" + # TODO(bnell): Should use isinstance but can't. Maybe search for + # presence of quant_method.init_prepare_finalize? + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] for module in moe_modules: - module.quant_method.init_prepare_finalize() + module.quant_method.init_prepare_finalize(module) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index bfdea93669e3..89bd322c6f80 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Callable, Optional, Union import deep_ep import torch @@ -28,6 +28,8 @@ def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset + self.async_prepare = True + # The dispatch function returns a handle that the combine function # requires. Under DBO microbatching we must track one handle per # micro-batch to avoid races between threads. 
@@ -59,10 +61,16 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_dispatch(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, num_experts: int): + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: has_scales = token_scales is not None @@ -97,7 +105,7 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=False, + async_finish=self.async_prepare, allocate_on_comm_stream=False) dbo_yield_and_switch_from_comm_to_compute() @@ -105,6 +113,33 @@ def _do_dispatch(self, tokens: torch.Tensor, a2a_idx = dbo_current_ubatch_id() self.handles[a2a_idx] = handle + return lambda: self._receiver( + event, + has_scales, + token_data, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if self.async_prepare: + event.current_stream_wait() + if has_scales: expert_x, expert_x_scale = token_data else: @@ -121,6 +156,7 @@ def _do_dispatch(self, tokens: torch.Tensor, # DeepEP's topk_ids output refers to the local experts directly. Offset # the topk_ids to move it back to the global experts space so it aligns # with existing vLLM interfaces. + assert expert_topk_ids is not None expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, @@ -132,10 +168,28 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_num_tokens_per_expert_list, device=expert_x.device) + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: + # Quantize after dispatch. 
+ expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -146,9 +200,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> Callable: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -168,37 +220,37 @@ def prepare( ) if a1q_scale is not None and a1q_scale.numel() == 1: a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) + a1_post_scale = None else: - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 - (expert_x, _, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1, - token_scales=None, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) - # Quantize after dispatch. - expert_x_scale = None - if expert_x.numel() != 0: - expert_x, expert_x_scale = moe_kernel_quantize_input( - expert_x, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=False, - block_shape=quant_config.block_shape) + a1q = a1 + a1q_scale = None + a1_post_scale = a1_scale - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 6568a1cd31fe..51a09d4890f4 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import deep_ep import torch @@ -80,7 +80,6 @@ def _do_quant( self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, quant_dtype: Union[torch.dtype, str, None], 
per_act_token_quant: bool, @@ -115,7 +114,10 @@ def _do_quant( return x, x_scales - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -126,13 +128,11 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.ReceiverType: hidden_size = a1.size(1) a2a_idx = dbo_current_ubatch_id() - do_recv_hook = dbo_enabled() + # do_recv_hook = dbo_enabled() if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -160,20 +160,53 @@ def prepare( num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, - return_recv_hook=do_recv_hook) + return_recv_hook=True) self.handles[a2a_idx] = handle if recv_hook is not None: dbo_register_recv_hook(recv_hook) dbo_yield() + return (recv_hook, lambda hook: self._receiver(hook, expert_x, expert_num_tokens, + a1_scale, a1.dtype, quant_config)) + + def _receiver( + self, + hook: Optional[Callable], + expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_num_tokens: torch.Tensor, + a1_scale, + a1_dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if hook is not None: + hook() + expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 061b02172c44..157cb36d4ffd 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -56,9 +56,7 @@ def prepare( apply_router_weight_on_input: bool, # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b46f4be4b912..88063668e918 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -506,9 +506,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, 
quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 919a59a63e4c..4246b7ad69b6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -201,7 +201,7 @@ def maybe_make_prepare_finalize( # Note: init_prepare_finalize should only be called by # prepare_communication_buffer_for_model. - def init_prepare_finalize(self): + def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None prepare_finalize = self.maybe_make_prepare_finalize(self.moe) @@ -216,6 +216,7 @@ def init_prepare_finalize(self): self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( @@ -251,7 +252,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -406,7 +407,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -456,7 +457,7 @@ def forward_cuda( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -540,7 +541,7 @@ def forward_cpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -585,7 +586,7 @@ def forward_xpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -623,7 +624,7 @@ def forward_tpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -931,6 +932,10 @@ def __init__( 
dtype=moe.in_dtype, device=torch.cuda.current_device()) + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -1563,25 +1568,45 @@ def maybe_all_reduce_tensor_model_parallel( else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: hidden_states = F.pad(hidden_states, (0, self.hidden_size - og_hidden_states), mode='constant', value=0.0) - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. - if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) + + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name) + return fused_output[..., :og_hidden_states] else: - return torch.ops.vllm.moe_forward( - hidden_states, router_logits, - self.layer_name)[..., :og_hidden_states] + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return (shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states]) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -1592,7 +1617,10 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) - full_final_hidden_states = torch.empty_like(full_hidden_states) + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like( + full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start @@ -1641,9 +1669,21 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): logical_replica_count=self.logical_replica_count, ) + assert self.shared_experts is None or isinstance( + final_hidden_states, tuple) + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states, + non_blocking=True) + else: + full_shared_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[0], + non_blocking=True) 
+ full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[1], + non_blocking=True) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP @@ -1664,10 +1704,17 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_end, skip_result_store=chunk_start_ >= num_tokens) - return full_final_hidden_states + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, + full_fused_final_hidden_states) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. @@ -1687,6 +1734,15 @@ def forward_impl(self, hidden_states: torch.Tensor, hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if (not isinstance(self.quant_method.fused_experts, + FusedMoEModularKernel) + and self.shared_experts is not None): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1710,14 +1766,30 @@ def forward_impl(self, hidden_states: torch.Tensor, logical_replica_count=self.logical_replica_count, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) - return final_hidden_states + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is None: + return reduce_output(final_hidden_states) + else: + return ( + reduce_output(final_hidden_states[0]), + reduce_output(final_hidden_states[1]), + ) @classmethod def make_expert_params_mapping( @@ -1772,17 +1844,22 @@ def extra_repr(self) -> str: return s -def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - assert self.quant_method is not None - + assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) -def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1795,6 +1872,37 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, tags=(torch.Tag.needs_fixed_stride_order, ), ) + +def moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.shared_experts is not None + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = torch.empty_like(hidden_states) + fused_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=moe_forward_shared, + mutates_args=["hidden_states"], + fake_impl=moe_forward_shared_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + # Mark the FusedMoE weight_loader as supporting MoE-specific parameters # to avoid expensive runtime reflection in model loading code FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a96cc9520328..dcdefc06a27e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Callable, Optional, Union, final import torch @@ -13,6 +13,11 @@ from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv 
+from vllm.v1.worker.ubatching import (dbo_enabled, + dbo_current_ubatch_id, + dbo_yield, + dbo_maybe_run_recv_hook, + dbo_register_recv_hook) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -141,6 +146,29 @@ def apply(self, output: Optional[torch.Tensor], raise NotImplementedError +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. +# - quantized + dispatched a1_scales. +# - Optional ExpertTokensMetadata containing gpu/cpu tensors +# as big as the number of local experts with the information about the +# number of tokens assigned to each local expert. +# - Optional dispatched expert topk IDs +# - Optional dispatched expert topk weight +# +# See `prepare` method below. +# +PrepareResultType = tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], +] + +ReceiverType = Callable[[], PrepareResultType] + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -160,16 +188,9 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> PrepareResultType: """ - Perform any quantization (and/or) dispatching needed - for this kernel. + Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make @@ -193,6 +214,51 @@ def prepare( """ raise NotImplementedError + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async. + """ + return False + + def prepare_async( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> ReceiverType: + """ + Perform any quantization (and/or) dispatching needed for this kernel + but do not wait for results from other workers. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `prepare`, e.g. + + receiver = obj.prepare_async(...) + a, a_scales, expert_meta, topk_ids, topk_weights = receiver() + + is equivalent to: + + a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) 
+ """ + raise NotImplementedError + @abstractmethod def finalize( self, @@ -473,10 +539,12 @@ def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, + shared_experts: Optional[torch.nn.Module] = None, ): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + self.shared_experts = shared_experts assert prepare_finalize.activation_format == \ fused_experts.activation_formats[0], ( f"{prepare_finalize.__class__.__name__}." @@ -715,7 +783,7 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -759,18 +827,54 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + shared_output: torch.Tensor + + if (not self.prepare_finalize.supports_async() + or self.shared_experts is None): + assert False + + # Run shared experts serially with dispatch. + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + dbo_maybe_run_recv_hook() + hook, receiver = self.prepare_finalize.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + assert self.shared_experts is not None + shared_output = self.shared_experts(a1) + + dbo_register_recv_hook(hook) + dbo_yield() + + if dbo_enabled(): + hook = None + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = receiver(hook) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -818,4 +922,7 @@ def forward( self.fused_experts.finalize_weight_and_reduce_impl(), ) - return output + if self.shared_experts is None: + return output + else: + return shared_output, output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index f64dd94ecb6e..1d9f49759a83 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -87,12 +87,15 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -103,9 +106,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.ReceiverType: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K a2a_idx = dbo_current_ubatch_id() @@ -142,6 +143,8 @@ def prepare( _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape) + orig_a_scale_block_shape: Optional[int] = None + if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -210,8 +213,45 @@ def prepare( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + return lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + a1q, + a1q_scale, + topk_ids, + bound_m, + orig_a_scale_block_shape, + ) + + def _receiver( + self, + expert_num_tokens: torch.Tensor, + expert_x: torch.Tensor, + expert_x_scale: Optional[torch.Tensor], + a1q: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + bound_m: Optional[torch.Tensor], + orig_a_scale_block_shape: Optional[int], + ) -> mk.PrepareResultType: + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, bound_m=bound_m, + do_send=False, + do_recv=True, ) dbo_yield_and_switch_from_comm_to_compute() @@ -224,6 +264,31 @@ def prepare( return expert_x, expert_x_scale, expert_tokens_meta, None, None + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + return receiver() + def finalize( self, output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 567a0a88fec0..bd9f7d4a06b1 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ 
b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -38,9 +38,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 287d66b06d6e..a4dd4f59c77f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from torch.nn import Parameter @@ -504,7 +504,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index b7897a43793c..1f9f95d06667 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -473,7 +473,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm.model_executor.layers.fused_moe import fused_experts assert self.fused_experts is None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6279bb8b6057..df20300cff2c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,7 +3,7 @@ import enum from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat @@ -356,7 +356,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -816,7 +816,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -1064,7 +1064,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ 
-1368,7 +1368,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1599,7 +1599,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 3e43caa4cbf7..188b2ac5c274 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -127,7 +127,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index be358cfa949f..790856ed2bb8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -961,7 +961,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 90222f2e3b0e..f4d828dc43c7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import gguf import torch @@ -539,7 +539,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c5d1e017014f..d2a0ac30a0e4 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -650,7 +650,7 @@ def apply( 
expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 72864853f7e0..fefc7ea05f8d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -489,7 +489,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") @@ -1359,7 +1359,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 364d1ac314d2..15a7c8589e72 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -306,7 +306,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index bdeb169a4b97..d393690e0a7c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from torch.nn.parameter import Parameter @@ -466,7 +466,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 58f56c6381b3..446928c19ac8 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -225,7 
+225,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -387,7 +387,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 8bdb50e07b13..24d929d5940c 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,7 +3,7 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -290,7 +290,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/shared_fused_moe/__init__.py b/vllm/model_executor/layers/shared_fused_moe/__init__.py new file mode 100644 index 000000000000..b87c69d3edd0 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import ( + SharedFusedMoE) + +__all__ = ["SharedFusedMoE"] diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py new file mode 100644 index 000000000000..ad7277425ef0 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. 
+ """ + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + shared_out = self._shared_experts(hidden_states) + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7657e7cb003d..cf51b82903d0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -48,6 +48,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -146,61 +147,81 @@ def __init__( self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + self.shared_experts = None + else: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + 
renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. 
/ self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output if self.tp_size > 1: final_hidden_states = ( diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index fe5e46a99826..fe467e1e5517 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -181,6 +181,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) + else: + shared_output = None router_logits = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts( hidden_states=hidden_states, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index ba08e6f81f7f..d28a8623272c 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -36,6 +36,7 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -73,7 +74,18 @@ def __init__(self, quant_config=None, prefix=f"{prefix}.router") - self.experts = FusedMoE( + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -83,16 +95,7 @@ def __init__(self, reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.shared_expert = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size_moe, - hidden_act="silu", - quant_config=quant_config, - bias=False, - prefix=f"{prefix}.shared_expert", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.experts", ) def forward(self, hidden_states): From 42a5e5537c9c8f407a3fec69c1fd68232438ac10 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 29 Aug 2025 17:43:13 +0000 Subject: [PATCH 2/8] minor fix Signed-off-by: Sage Moore --- .../layers/fused_moe/deepep_ll_prepare_finalize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 51a09d4890f4..610e37cc1a0f 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -152,7 +152,7 @@ def prepare_async( a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - dbo_maybe_run_recv_hook() + # dbo_maybe_run_recv_hook() expert_x, expert_num_tokens, handle, _, recv_hook= \ self.buffers[a2a_idx].low_latency_dispatch(a1, topk_ids, @@ -162,9 +162,9 @@ def prepare_async( async_finish=False, return_recv_hook=True) self.handles[a2a_idx] = handle - if recv_hook is not None: - dbo_register_recv_hook(recv_hook) - dbo_yield() + # if recv_hook is not None: + # dbo_register_recv_hook(recv_hook) + # dbo_yield() return (recv_hook, lambda hook: self._receiver(hook, expert_x, 
expert_num_tokens, a1_scale, a1.dtype, quant_config)) From c42f57aba9c7512664838854b42a108098e7c03d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 28 Aug 2025 15:21:55 -0700 Subject: [PATCH 3/8] [Perf][V1] Fully overlap model execution Co-authored-by: Benjamin Chislett Signed-off-by: Nick Hill --- vllm/v1/executor/multiproc_executor.py | 28 ++++- vllm/v1/outputs.py | 28 +++++ vllm/v1/worker/gpu_input_batch.py | 5 + vllm/v1/worker/gpu_model_runner.py | 153 +++++++++++++++++++++---- vllm/v1/worker/gpu_worker.py | 10 +- 5 files changed, 195 insertions(+), 29 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 12e79ff165f4..e53f3ce099ff 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -3,6 +3,7 @@ import multiprocessing import os import pickle +import queue import signal import threading import time @@ -18,6 +19,7 @@ from typing import Any, Callable, Optional, Union, cast import cloudpickle +import torch import vllm.envs as envs from vllm.config import VllmConfig @@ -33,7 +35,8 @@ get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, + ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -412,6 +415,14 @@ def __init__( # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_stream = torch.cuda.Stream() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy") + self.async_output_copy_thread.start() + # Initialize device and loads weights self.worker.init_device() self.worker.load_model() @@ -593,6 +604,18 @@ class ResponseStatus(Enum): SUCCESS = auto() FAILURE = auto() + def enqueue_worker_output(self, output: Any) -> None: + if isinstance(output, AsyncModelRunnerOutput): + output = output.serialize(self.async_output_copy_stream) + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + output = self.async_output_queue.get() + self.enqueue_worker_output(output) + def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: @@ -617,5 +640,4 @@ def worker_busy_loop(self): continue if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + self.async_output_queue.put(output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f8d6b24702f3..22da4b34e554 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -114,6 +114,34 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +# ModelRunnerOutput wrapper for async scheduling. +# Contains GPU tensors which must be serialized before sending +# to the scheduler process. 
+@dataclass +class AsyncModelRunnerOutput: + model_runner_output: ModelRunnerOutput + + # [num_reqs, max_num_generated_tokens] + sampled_token_ids: torch.Tensor + + invalid_req_indices: list[int] + + def serialize(self, copy_stream: torch.cuda.Stream) -> ModelRunnerOutput: + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(copy_stream): + copy_stream.wait_stream(default_stream) + sampled_token_ids_cpu = self.sampled_token_ids.to( + 'cpu', non_blocking=True) + copy_stream.synchronize() + valid_sampled_token_ids = sampled_token_ids_cpu.tolist() + for i in self.invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self.model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + @dataclass class DraftTokenIds: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index f4c2f45df595..993210c244bb 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -250,6 +250,11 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 174129222183..38975adca378 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -68,8 +68,8 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - LogprobsTensors, ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -243,6 +243,8 @@ def __init__( is_pooling_model=self.is_pooling_model, ) + self.use_async_scheduling = self.scheduler_config.async_scheduling + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -718,6 +720,73 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is not None: + # Async scheduling case, we need to copy the sampled token ids + # from the previous iteration. 
+ prev_req_id_to_index = self.input_batch.prev_req_id_to_index + current_req_id_to_index = self.input_batch.req_id_to_index + assert prev_req_id_to_index is not None + common_req_ids = set(prev_req_id_to_index.keys()).intersection( + set(current_req_id_to_index.keys())) + if common_req_ids: + current_common_req_indices = [ + current_req_id_to_index[req_id] + for req_id in common_req_ids + ] + prev_common_req_indices = [ + prev_req_id_to_index[req_id] for req_id in common_req_ids + ] + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_indices = [ + int(cu_num_tokens[idx]) - 1 + for idx in current_common_req_indices + ] + if len(flattened_indices) < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if flattened_indices == prev_common_req_indices and \ + set(flattened_indices) == \ + set(range(len(flattened_indices))): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 + self.input_ids.gpu[:len(flattened_indices)].copy_( + self.input_batch.prev_sampled_token_ids[:len( + flattened_indices)].squeeze(1), + non_blocking=True) + else: + # Upload the index tensors asynchronously + # so the scatter can be non-blocking + input_ids_index_tensor = torch.tensor( + flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, + non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor].squeeze(1)) + else: + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + else: + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + def _prepare_inputs( self, scheduler_output: "SchedulerOutput" ) -> tuple[PerLayerAttnMetadata, torch.Tensor, @@ -809,7 +878,8 @@ def _prepare_inputs( max_seq_len = self.seq_lens.np[:num_reqs].max().item() # Copy the tensors to the GPU. - self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( @@ -1679,7 +1749,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -1891,6 +1961,12 @@ def execute_model( # so that we could clear the sampled tokens before returning. discard_sampled_tokens_req_indices.append(i) + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. 
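Under async scheduling this device-to-host sync is moved off the critical path: the runner returns the sampled token IDs as a GPU tensor inside AsyncModelRunnerOutput, and the worker's output-copy thread performs the transfer on a dedicated stream via serialize() (shown above). A condensed sketch of that copy pattern follows; the helper name is illustrative and a CUDA device is assumed.

# Sketch only: mirrors AsyncModelRunnerOutput.serialize(); the function name
# is illustrative and a CUDA device is assumed.
import torch

def copy_sampled_ids_off_critical_path(
        sampled_token_ids: torch.Tensor,       # [num_reqs, 1], on the GPU
        invalid_req_indices: list[int],
        copy_stream: torch.cuda.Stream) -> list[list[int]]:
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Wait for the sampler's writes before reading the tensor.
        copy_stream.wait_stream(default_stream)
        sampled_cpu = sampled_token_ids.to("cpu", non_blocking=True)
    # Only the output-copy thread blocks here; the default stream keeps
    # executing the next scheduled step.
    copy_stream.synchronize()
    token_lists = sampled_cpu.tolist()
    for i in invalid_req_indices:
        token_lists[i].clear()  # discard tokens for invalid requests
    return token_lists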
logprobs_tensors = sampler_output.logprobs_tensors @@ -1903,21 +1979,41 @@ def execute_model( scheduler_output.num_scheduled_tokens, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = self._to_list(sampled_token_ids) + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. @@ -1925,7 +2021,12 @@ def execute_model( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. 
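The GPU-resident prev_sampled_token_ids cached above is consumed on the next engine step by _prepare_input_ids(), which scatters each still-running request's previous token into its last-token slot of the flattened input_ids buffer without a host round trip. A stripped-down sketch of that scatter, with illustrative tensor names and single-token decodes assumed:

# Sketch only: a simplified version of the scatter in _prepare_input_ids();
# names and shapes are illustrative and assume one sampled token per request.
import torch

def scatter_prev_sampled_tokens(
        input_ids_gpu: torch.Tensor,           # [num_scheduled_tokens]
        prev_sampled_token_ids: torch.Tensor,  # [num_prev_reqs, 1], on the GPU
        prev_req_indices: list[int],           # rows in prev_sampled_token_ids
        dest_positions: list[int],             # last-token slot per request
        device: torch.device) -> None:
    index = torch.tensor(dest_positions, dtype=torch.int64).to(
        device, non_blocking=True)
    src_rows = torch.tensor(prev_req_indices, dtype=torch.int64).to(
        device, non_blocking=True)
    # Write each request's token from the previous step into its slot.
    input_ids_gpu.scatter_(dim=0,
                           index=index,
                           src=prev_sampled_token_ids[src_rows].squeeze(1))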
req_ids = self.input_batch.req_ids - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1940,6 +2041,7 @@ def execute_model( start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) @@ -1961,9 +2063,9 @@ def execute_model( self.eplb_step() - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -1972,6 +2074,15 @@ def execute_model( num_nans_in_logits=num_nans_in_logits, ) + if self.use_async_scheduling: + return AsyncModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + ) + + return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2088bfff5bb3..f4077bbd6492 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,7 +5,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed @@ -28,8 +28,8 @@ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -352,7 +352,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 if forward_pass and not get_pp_group().is_first_rank: @@ -362,7 +362,7 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - if isinstance(output, ModelRunnerOutput): + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output assert isinstance(output, IntermediateTensors) From 2acb1f5073a51d07da2c9b3a79bcdb19944ae5a7 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 29 Aug 2025 22:13:24 +0000 Subject: [PATCH 4/8] nccl backed comms for ubatch coordination and padding bugfix Signed-off-by: Sage Moore --- vllm/forward_context.py | 27 ++++++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 17 +++++++++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/vllm/forward_context.py 
b/vllm/forward_context.py index dd25c272e034..e29201b223cf 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -93,27 +93,32 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, return num_tokens_tensor @staticmethod - def should_ubatch_across_dp(should_ubatch: bool, num_tokens_per_ubatch: int, dp_size: int, + def should_ubatch_across_dp(should_ubatch: bool, orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, dp_size: int, dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]: - tensor = torch.zeros(3, dp_size, device="cpu", dtype=torch.int32) - tensor[0][dp_rank] = num_tokens_per_ubatch - tensor[1][dp_rank] = 1 if should_ubatch else 0 + tensor = torch.zeros(3, dp_size, device="cuda", dtype=torch.int32) + tensor[0][dp_rank] = orig_num_tokens_per_ubatch + tensor[1][dp_rank] = padded_num_tokens_per_ubatch + tensor[2][dp_rank] = 1 if should_ubatch else 0 from vllm.distributed.parallel_state import get_dp_group - dist.all_reduce(tensor, group=get_dp_group().cpu_group) + dist.all_reduce(tensor, group=get_dp_group().device_group) - result: bool = bool(torch.all(tensor[1]== 1).item()) + result: bool = bool(torch.all(tensor[2]== 1).item()) if not result: return result, None - min_num_tokens_per_ubatch = tensor[0].min().item() - max_num_tokens_per_ubatch = tensor[0].max().item() - if max_num_tokens_per_ubatch >= 2 * min_num_tokens_per_ubatch: - logger.debug(f"Aborting ubatching {min_num_tokens_per_ubatch} {max_num_tokens_per_ubatch}") + orig_num_tokens_tensor = tensor[0, :] + padded_num_tokens_tensor = tensor[1, :] + + orig_min_num_tokens = orig_num_tokens_tensor.min().item() + padded_max_num_tokens = padded_num_tokens_tensor.max().item() + if padded_max_num_tokens >= 2 * orig_min_num_tokens: + logger.debug(f"Aborting ubatching {orig_min_num_tokens} {padded_max_num_tokens}") return False, None - return result, tensor[0, :] + return result, padded_num_tokens_tensor @staticmethod def make( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 38975adca378..b669020db264 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1608,7 +1608,7 @@ def get_dp_padding_ubatch( return False, 0, None if ubatch_slices is None: - (should_ubatch, num_tokens_across_dp) = self.should_ubatch_with_num_tokens(False, 0) + (should_ubatch, num_tokens_across_dp) = self.should_ubatch_with_num_tokens(False, 0, 0) assert should_ubatch is False assert num_tokens_across_dp is None return should_ubatch, 0, num_tokens_across_dp @@ -1651,7 +1651,7 @@ def get_dp_padding_ubatch( # Note that we compute the number of padded tokens per ubatch should_ubatch, num_tokens_across_dp= self.should_ubatch_with_num_tokens(should_ubatch, - num_tokens_per_ubatch) + num_tokens_unpadded // 2, num_tokens_per_ubatch) if not should_ubatch: assert num_tokens_across_dp is None return should_ubatch, 0, num_tokens_across_dp @@ -1675,7 +1675,7 @@ def get_dp_padding_ubatch( def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices, num_pad_tokens: int): original_num_tokens = ubatch_slices[1].token_slice.stop - assert num_pad_tokens < original_num_tokens + assert num_pad_tokens < original_num_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {original_num_tokens}" total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2 padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch) @@ -1699,11 +1699,16 @@ def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, ubatch_slices[1] = 
UbatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice) - def should_ubatch_with_num_tokens(self, should_ubatch: bool, num_tokens_per_ubatch: int, + def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, ) -> tuple[bool, Optional[torch.Tensor]]: dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank - return DPMetadata.should_ubatch_across_dp(should_ubatch, num_tokens_per_ubatch, dp_size, dp_rank) + return DPMetadata.should_ubatch_across_dp(should_ubatch, + orig_num_tokens_per_ubatch, + padded_num_tokens_per_ubatch, + dp_size, + dp_rank) def _pool( self, @@ -2572,7 +2577,7 @@ def _dummy_run( should_ubatch = num_tokens >= \ self.parallel_config.microbatching_token_threshold and \ allow_microbatching - should_ubatch, _ = self.should_ubatch_with_num_tokens(should_ubatch, num_tokens // 2,) + should_ubatch, _ = self.should_ubatch_with_num_tokens(should_ubatch, num_tokens // 2, num_tokens // 2,) assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } From 0d7e462205cb17db7dc16325e11108e4968c33cb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 31 Aug 2025 03:37:37 +0000 Subject: [PATCH 5/8] wip alt schedule Signed-off-by: Lucas Wilkinson --- vllm/compilation/ubatch_wrapper.py | 123 +++++++++--------- .../fused_moe/deepep_ll_prepare_finalize.py | 83 ++++++------ .../layers/fused_moe/modular_kernel.py | 57 ++++---- .../layers/fused_moe/pplx_prepare_finalize.py | 92 +++++++------ vllm/model_executor/models/deepseek_v2.py | 99 ++++++-------- vllm/v1/attention/backends/mla/common.py | 16 ++- vllm/v1/worker/ubatching.py | 103 ++++++++++++--- 7 files changed, 304 insertions(+), 269 deletions(-) diff --git a/vllm/compilation/ubatch_wrapper.py b/vllm/compilation/ubatch_wrapper.py index eba5e81c90d5..5bc3df813395 100644 --- a/vllm/compilation/ubatch_wrapper.py +++ b/vllm/compilation/ubatch_wrapper.py @@ -2,32 +2,26 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from contextlib import ExitStack -from typing import Any, Callable, Optional -from unittest.mock import patch import threading +from typing import Any, Callable, Optional import torch import vllm.envs as envs -from vllm.compilation.counter import compilation_counter -from vllm.compilation.monitor import validate_cudagraph_capturing_enabled +from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import CUDAGraphMode, VllmConfig -from vllm.forward_context import (BatchDescriptor, get_forward_context, - create_forward_context, +from vllm.distributed.parallel_state import is_global_first_rank +from vllm.forward_context import (create_forward_context, get_forward_context, override_forward_context) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors -from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts from vllm.sequence import IntermediateTensors -from vllm.compilation.cuda_graph import CUDAGraphWrapper -from vllm.distributed.parallel_state import is_global_first_rank - - +from vllm.v1.worker.ubatching import (Schedule, UBatchContext, + make_ubatch_contexts) logger = init_logger(__name__) + @dataclasses.dataclass class UbatchMetadata: context: UBatchContext @@ -37,19 +31,18 @@ class UbatchMetadata: intermediate_tensors: Optional[IntermediateTensors] num_tokens: int + @dataclasses.dataclass class 
CUDAGraphMetaData: cudagraph: torch.cuda.CUDAGraph ubatch_metadata: UbatchMetadata outputs: Optional[Any] = None + class UBatchWrapper: - def __init__(self, - runnable: Callable, - vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, - device: torch.cuda.device): + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, device: torch.cuda.device): self.runnable = runnable self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -62,7 +55,8 @@ def __init__(self, self.cudagraph_wrapper = None self.graph_pool = None if runtime_mode is not CUDAGraphMode.NONE: - self.cudagraph_wrapper = CUDAGraphWrapper(runnable, vllm_config, runtime_mode=runtime_mode) + self.cudagraph_wrapper = CUDAGraphWrapper( + runnable, vllm_config, runtime_mode=runtime_mode) self.graph_pool = current_platform.get_global_graph_pool() self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" @@ -115,7 +109,7 @@ def _capture_ubatch_thread(results, ubatch_metadata): )) ubatch_threads.append(thread) thread.start() - self.ready_barrier.wait() # Wait for both threads to be ready + self.ready_barrier.wait() # Wait for both threads to be ready # DO capture cudagraph_metadata = \ @@ -164,18 +158,18 @@ def _ubatch_thread(results, model, ubatch_metadata): )) ubatch_threads.append(thread) thread.start() - self.ready_barrier.wait() # Wait for both threads to be ready + self.ready_barrier.wait() # Wait for both threads to be ready ubatch_metadata[0].context.cpu_wait_event.set() for thread in ubatch_threads: thread.join() sorted_results = [value for position, value in sorted(results)] result = torch.cat(sorted_results, dim=0) return result - - def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, - input_ids, positions, inputs_embeds, - intermediate_tensors, compute_stream, - num_tokens_across_dp, batch_descriptor, + + def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, + positions, inputs_embeds, intermediate_tensors, + compute_stream, num_tokens_across_dp, + batch_descriptor, cudagraph_runtime_mode) -> list[UbatchMetadata]: # Create one forward context per ubatch @@ -199,7 +193,8 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, forward_contexts=forward_contexts, ready_barrier=self.ready_barrier, device=self.device, - enable_async_comms=self.vllm_config.parallel_config.enable_async_comms) + schedule=Schedule.MLP_OVERLAP, + ) ubatch_metadata: list[UbatchMetadata] = [] for i, ubatch_slice in enumerate(ubatch_slices): @@ -209,28 +204,31 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, ubatch_slice.token_slice, input_ids, positions, inputs_embeds, intermediate_tensors) ubatch_metadata.append( - UbatchMetadata(context=ubatch_ctxs[i], - input_ids=sliced_input_ids, - positions=sliced_positions, - inputs_embeds=sliced_inputs_embeds, - intermediate_tensors=sliced_intermediate_tensors, - num_tokens=ubatch_slice.token_slice.stop - - ubatch_slice.token_slice.start)) + UbatchMetadata( + context=ubatch_ctxs[i], + input_ids=sliced_input_ids, + positions=sliced_positions, + inputs_embeds=sliced_inputs_embeds, + intermediate_tensors=sliced_intermediate_tensors, + num_tokens=ubatch_slice.token_slice.stop - + ubatch_slice.token_slice.start)) return ubatch_metadata - - def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, inputs_embeds, intermediate_tensors): + def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, + inputs_embeds, intermediate_tensors): sliced_input_ids = 
input_ids[tokens_slice] # if we are using mrope if positions.ndim == 2: sliced_positions = positions[:, tokens_slice] else: sliced_positions = positions[tokens_slice] - sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None - sliced_intermediate_tensors = intermediate_tensors[tokens_slice] if intermediate_tensors else None + sliced_inputs_embeds = inputs_embeds[ + tokens_slice] if inputs_embeds else None + sliced_intermediate_tensors = intermediate_tensors[ + tokens_slice] if intermediate_tensors else None - return (sliced_input_ids, sliced_positions, sliced_inputs_embeds, + return (sliced_input_ids, sliced_positions, sliced_inputs_embeds, sliced_intermediate_tensors) def __call__(self, *args, **kwargs): @@ -248,7 +246,8 @@ def __call__(self, *args, **kwargs): return self.cudagraph_wrapper(*args, **kwargs) attn_metadata = forward_context.attn_metadata - num_tokens = (ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start) * 2 + num_tokens = (ubatch_slices[0].token_slice.stop - + ubatch_slices[0].token_slice.start) * 2 input_ids = kwargs['input_ids'] positions = kwargs['positions'] intermediate_tensors = kwargs['intermediate_tensors'] @@ -266,16 +265,16 @@ def __call__(self, *args, **kwargs): if is_global_first_rank(): logger.debug(f"CAPTURING CUDAGRAPH {num_tokens}") ubatch_metadata = self._make_ubatch_metadata( - ubatch_slices=ubatch_slices, - attn_metadata=attn_metadata, - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - compute_stream=compute_stream, - num_tokens_across_dp=num_tokens_across_dp, - batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE) return self._capture_ubatches(ubatch_metadata, self.model) elif num_tokens in self.cudagraphs: @@ -286,16 +285,18 @@ def __call__(self, *args, **kwargs): return cudagraph_metadata.outputs else: if is_global_first_rank(): - logger.debug(f"RUNNING UBATCHED {num_tokens} CUDAGRAPH MODE {cudagraph_runtime_mode}") + logger.debug( + f"RUNNING UBATCHED {num_tokens} CUDAGRAPH MODE {cudagraph_runtime_mode}" + ) ubatch_metadata = self._make_ubatch_metadata( - ubatch_slices=ubatch_slices, - attn_metadata=attn_metadata, - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - compute_stream=compute_stream, - num_tokens_across_dp=num_tokens_across_dp, - batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE) return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 610e37cc1a0f..4b4c0cf4e922 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ 
b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from typing import Optional, Union import deep_ep import torch @@ -11,11 +11,9 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) -from vllm.v1.worker.ubatching import (dbo_enabled, - dbo_current_ubatch_id, - dbo_yield, +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, dbo_maybe_run_recv_hook, - dbo_register_recv_hook) + dbo_register_recv_hook, dbo_yield) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -132,7 +130,6 @@ def prepare_async( hidden_size = a1.size(1) a2a_idx = dbo_current_ubatch_id() - # do_recv_hook = dbo_enabled() if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -152,8 +149,7 @@ def prepare_async( a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - # dbo_maybe_run_recv_hook() - expert_x, expert_num_tokens, handle, _, recv_hook= \ + _expert_x, expert_num_tokens, handle, _, recv_hook= \ self.buffers[a2a_idx].low_latency_dispatch(a1, topk_ids, self.max_tokens_per_rank, @@ -162,33 +158,19 @@ def prepare_async( async_finish=False, return_recv_hook=True) self.handles[a2a_idx] = handle - # if recv_hook is not None: - # dbo_register_recv_hook(recv_hook) - # dbo_yield() - return (recv_hook, lambda hook: self._receiver(hook, expert_x, expert_num_tokens, - a1_scale, a1.dtype, quant_config)) + def _post_recv() -> mk.PrepareResultType: + expert_x, expert_x_scale = self._do_quant( + _expert_x, a1_scale, a1.dtype, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) - def _receiver( - self, - hook: Optional[Callable], - expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - expert_num_tokens: torch.Tensor, - a1_scale, - a1_dtype, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - if hook is not None: - hook() - - expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None - return expert_x, expert_x_scale, expert_tokens_meta, None, None + return (recv_hook, _post_recv) def prepare( self, @@ -202,11 +184,19 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) - return receiver() + recv_hook, post_recv = self.prepare_async( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + recv_hook() + return post_recv() def finalize( self, @@ -233,14 +223,15 @@ def finalize( # TODO (varun) : Enable zero copy mode dbo_maybe_run_recv_hook() - _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output, - topk_ids, - combine_topk_weights, - handle, - async_finish=False, - zero_copy=False, - 
return_recv_hook=do_recv_hook, - out=output) + _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + handle, + async_finish=False, + zero_copy=False, + return_recv_hook=do_recv_hook, + out=output) if recv_hook is not None: - dbo_register_recv_hook(recv_hook) - dbo_yield() \ No newline at end of file + dbo_register_recv_hook(recv_hook, all_schedules=True) + dbo_yield(all_schedules=True) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index dcdefc06a27e..25db394eec17 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,11 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv -from vllm.v1.worker.ubatching import (dbo_enabled, - dbo_current_ubatch_id, - dbo_yield, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook) +from vllm.v1.worker.ubatching import (Schedule, dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -502,12 +499,14 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, class SharedResizableBuffer: + def __init__(self): self.buffer = None - + # NOTE: Assumes the first call to get() is the largest shape, # this is usually true due to the profile run. - def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): + def get(self, shape: tuple[int, ...], device: torch.device, + dtype: torch.dtype): shape_numel = prod(shape) if self.buffer is None or self.buffer.numel() < shape_numel: self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) @@ -584,14 +583,12 @@ def _do_fused_experts( # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = self.workspace13_buffer.get( - workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = self.workspace2_buffer.get( - workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + workspace13 = self.workspace13_buffer.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = self.workspace2_buffer.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -683,10 +680,9 @@ def _maybe_chunk_fused_experts( (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = self.fused_out_buffer.get( - fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + fused_out = self.fused_out_buffer.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -829,8 +825,7 @@ def forward( shared_output: torch.Tensor - if (not self.prepare_finalize.supports_async() - or self.shared_experts is None): + if not self.prepare_finalize.supports_async(): assert False # Run shared experts serially with dispatch. 
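The prepare_async() contract used here splits dispatch into a send phase and a receive phase: it returns a (recv_hook, post_recv) pair, so the caller can overlap other work, such as the shared experts or the peer microbatch, between issuing the all2all and waiting on it. A minimal sketch of a caller driving that pair; the function and the overlapped-work callback are illustrative placeholders.

# Sketch only: illustrates the (recv_hook, post_recv) pattern returned by
# prepare_async(); "run_overlapped_work" is an illustrative placeholder.
def dispatch_with_overlap(prepare_finalize, run_overlapped_work,
                          *prepare_args):
    if not prepare_finalize.supports_async():
        # Fallback: prepare() sends and waits in a single call.
        return prepare_finalize.prepare(*prepare_args), None

    recv_hook, post_recv = prepare_finalize.prepare_async(*prepare_args)
    overlapped_out = run_overlapped_work()  # e.g. the shared-expert MLP
    recv_hook()    # wait for the dispatched tokens to arrive
    return post_recv(), overlapped_out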
@@ -864,17 +859,23 @@ def forward( self.fused_experts.quant_config, ) - assert self.shared_experts is not None - shared_output = self.shared_experts(a1) + if dbo_register_recv_hook(hook, + schedules=(Schedule.MLP_OVERLAP, )): + hook = lambda: None + + dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) + + # assert self.shared_experts is not None + if self.shared_experts is not None: + assert False + shared_output = self.shared_experts(a1) - dbo_register_recv_hook(hook) - dbo_yield() + dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP, )) - if dbo_enabled(): - hook = None + hook() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = receiver(hook) + _expert_topk_weights) = receiver() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 1d9f49759a83..756fc10d26e9 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -14,8 +14,8 @@ _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, - dbo_yield_and_switch_from_comm_to_compute, - dbo_yield_and_switch_from_compute_to_comm) + dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield) logger = init_logger(__name__) @@ -206,7 +206,6 @@ def prepare_async( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - dbo_yield_and_switch_from_compute_to_comm() self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -219,50 +218,32 @@ def prepare_async( do_recv=False, ) - return lambda: self._receiver( - expert_num_tokens, - expert_x, - expert_x_scale, - a1q, - a1q_scale, - topk_ids, - bound_m, - orig_a_scale_block_shape, - ) - - def _receiver( - self, - expert_num_tokens: torch.Tensor, - expert_x: torch.Tensor, - expert_x_scale: Optional[torch.Tensor], - a1q: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - bound_m: Optional[torch.Tensor], - orig_a_scale_block_shape: Optional[int], - ) -> mk.PrepareResultType: + def _recv_hook(self): + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) - self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=topk_ids, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) - dbo_yield_and_switch_from_comm_to_compute() + def _post_recv() -> mk.PrepareResultType: + if expert_x_scale is not None: + expert_x_scale = expert_x_scale[:, :, : + orig_a_scale_block_shape] + assert expert_x_scale.ndim == 3 - if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - assert expert_x_scale.ndim == 3 + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None - return expert_x, expert_x_scale, 
expert_tokens_meta, None, None + return _recv_hook, _post_recv def prepare( self, @@ -276,7 +257,7 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async( + recv_hook, post_recv = self.prepare_async( a1, a1_scale, a2_scale, @@ -287,7 +268,8 @@ def prepare( apply_router_weight_on_input, quant_config, ) - return receiver() + recv_hook() + return post_recv() def finalize( self, @@ -321,12 +303,28 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - dbo_yield_and_switch_from_compute_to_comm() + dbo_maybe_run_recv_hook() self.a2as[a2a_idx].combine( out_tokens=output, indices=topk_ids.view(dtype=torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m, + do_send=False, + do_recv=True, ) - dbo_yield_and_switch_from_comm_to_compute() + + def recv_hook(): + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids.view(dtype=torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + + if recv_hook is not None: + dbo_register_recv_hook(recv_hook, all_schedules=True) + dbo_yield(all_schedules=True) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e4cab4fb8202..588663f44ce6 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -147,83 +146,61 @@ def __init__( self.n_local_physical_experts) self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - - if config.n_shared_experts is None: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - self.shared_experts = None - else: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) - self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, 
intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - fused_moe_out = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - - if self.shared_experts is not None: - shared_output, final_hidden_states = fused_moe_out - else: - shared_output = None - final_hidden_states = fused_moe_out - - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - final_hidden_states *= self.routed_scaling_factor - elif self.shared_experts is not None: - assert shared_output is not None - shared_output *= (1. / self.routed_scaling_factor) - - if self.shared_experts is not None: - assert shared_output is not None - final_hidden_states += shared_output - + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + if hidden_states.dtype != torch.float16: + final_hidden_states = final_hidden_states + shared_output + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = final_hidden_states + shared_output \ + * (1. 
/ self.routed_scaling_factor) if self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7d05f2ea6bb0..b53247f57f84 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -585,7 +585,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens, ) - def build_for_cudagraph_capture(self, common_attn_metadata: CommonAttentionMetadata): + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): """ return - num_decodes: number of decode requests @@ -1269,11 +1270,6 @@ def forward( if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) - if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) - if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( @@ -1308,6 +1304,14 @@ def forward( layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) + #dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP,)) + + if has_prefill: + output[num_decode_tokens:] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata, layer._k_scale) + + if has_decode: output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index ed525adfe0ed..c7d557c355b2 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from contextlib import contextmanager +from enum import Enum from typing import Optional import torch @@ -12,6 +14,12 @@ _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] + +class Schedule(Enum): + MLP_OVERLAP = "mlp_overlap" + MLA_ATTN_OVERLAP = "mla_attn_overlap" + + class UBatchContext: """ Context manager for micro-batching synchronization using threading events. 
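The hunk above adds the `Schedule` enum; the hunks that follow turn the `dbo_*` helpers into schedule-gated wrappers (`dbo_yield(schedules=...)`, `dbo_register_recv_hook(..., all_schedules=...)`). A minimal standalone sketch of that gating pattern is below; `_ToyContext`, `_gated`, and the `__main__` driver are illustrative stand-ins and are not part of this patch.

import threading
from enum import Enum
from typing import Callable

class Schedule(Enum):
    MLP_OVERLAP = "mlp_overlap"
    MLA_ATTN_OVERLAP = "mla_attn_overlap"

class _ToyContext:
    """Stand-in for UBatchContext: just records how often it yielded."""
    def __init__(self, schedule: Schedule):
        self.schedule = schedule
        self.yield_count = 0

    def yield_(self) -> None:
        # In the real code this hands control to the other micro-batch thread.
        self.yield_count += 1

_THREAD_ID_TO_CONTEXT: dict[int, _ToyContext] = {}

def _gated(func: Callable[[_ToyContext], None]) -> Callable[..., None]:
    # Only invoke `func` when DBO is active and the current context's schedule
    # is one of the requested schedules (or all_schedules is set).
    def wrapper(schedules: tuple[Schedule, ...] = (),
                all_schedules: bool = False) -> None:
        ctx = _THREAD_ID_TO_CONTEXT.get(threading.get_ident())
        if ctx is None:
            return  # DBO disabled: the helper is a no-op.
        if all_schedules or ctx.schedule in schedules:
            func(ctx)
    return wrapper

dbo_yield = _gated(_ToyContext.yield_)

if __name__ == "__main__":
    ctx = _ToyContext(Schedule.MLP_OVERLAP)
    _THREAD_ID_TO_CONTEXT[threading.get_ident()] = ctx
    dbo_yield(schedules=(Schedule.MLP_OVERLAP,))       # runs
    dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP,))  # skipped
    print(ctx.yield_count)  # -> 1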
@@ -27,8 +35,7 @@ def __init__(self, cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event, - enable_async_comms: bool, - schedule: str = "default"): + schedule: Schedule = Schedule.MLP_OVERLAP): self.id = id self.comm_stream = comm_stream self.compute_stream = compute_stream @@ -39,7 +46,6 @@ def __init__(self, self.current_stream = compute_stream self.gpu_comm_done_event = gpu_comm_done_event self.gpu_compute_done_event = gpu_compute_done_event - self.enable_async_comms = enable_async_comms self.schedule = schedule self.recv_hook = None @@ -111,9 +117,15 @@ def switch_to_comm_sync(self): self._signal_compute_done() self.update_stream(self.comm_stream) self._wait_comm_done() - + + def switch_to_compute_sync(self): + self._signal_comm_done() + self.update_stream(self.compute_stream) + self._wait_compute_done() + def maybe_run_recv_hook(self): if self.recv_hook is not None: + print("run recv hook", self.id, self.recv_hook) self.recv_hook() self.recv_hook = None @@ -122,7 +134,7 @@ def yield_(self): self._cpu_yield() if self.current_stream == current_stream(): self.update_stream(self.current_stream) - + def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream self._signal_compute_done() @@ -143,30 +155,83 @@ def yield_and_switch_from_comm_to_compute(self): def dbo_enabled() -> bool: return len(_THREAD_ID_TO_CONTEXT) > 0 + def dbo_current_ubatch_id() -> int: if len(_THREAD_ID_TO_CONTEXT) == 0: return 0 return _THREAD_ID_TO_CONTEXT[threading.get_ident()] -def _register_ubatch_function(func, context_offset): - def wrapper(*args, **kwargs): + +def _register_ubatch_function(func, all_schedules_default: bool = False): + + def wrapper(schedules: tuple[Schedule] = (), + all_schedules: bool = all_schedules_default): if len(_THREAD_ID_TO_CONTEXT) > 0: - ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + context_offset + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] ctx = _CURRENT_CONTEXTS[ctx_idx] - func(ctx, *args, **kwargs) + if all_schedules or ctx.schedule in schedules: + func(ctx) + return wrapper -dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(UBatchContext.yield_and_switch_from_compute_to_comm, 0) -dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(UBatchContext.yield_and_switch_from_comm_to_compute, 0) -dbo_yield = _register_ubatch_function(UBatchContext.yield_, 0) -dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook, 0) -dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync, 0) -def dbo_register_recv_hook(recv_hook): +dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( + UBatchContext.yield_and_switch_from_compute_to_comm) +dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( + UBatchContext.yield_and_switch_from_comm_to_compute) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) +# For `dbo_maybe_run_recv_hook` we usually want to run the recv hook for +# since its already conditional on a recv_hook being registered, so we default +# all_schedules to True to make the code a bit simpler. 
+dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook, all_schedules_default=True) +dbo_switch_to_comm_sync = _register_ubatch_function( + UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute_sync = _register_ubatch_function( + UBatchContext.switch_to_compute_sync) + + +@contextmanager +def dbo_switch_to_compute_sync(): + if dbo_enabled(): + dbo_switch_to_compute_sync() + yield + dbo_switch_to_comm_sync() + else: + yield + + +@contextmanager +def dbo_switch_to_comm_sync(): + if dbo_enabled(): + dbo_switch_to_comm_sync() + yield + dbo_switch_to_compute_sync() + else: + yield + + +@contextmanager +def dbo_run_on_comm_async(): + if dbo_enabled(): + dbo_run_on_comm_async() + yield + else: + yield + + +def dbo_register_recv_hook(recv_hook, + schedules: tuple[Schedule] = (), + all_schedules: bool = False) -> bool: if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] - next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] - next_ctx.recv_hook = recv_hook + if all_schedules or _CURRENT_CONTEXTS[ctx_idx].schedule in schedules: + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + print("register recv hook", next_ctx.id, recv_hook) + next_ctx.recv_hook = recv_hook + return True + return False + def make_ubatch_contexts( num_micro_batches: int, @@ -175,8 +240,7 @@ def make_ubatch_contexts( forward_contexts: list[ForwardContext], ready_barrier: threading.Barrier, device: Optional[torch.device] = None, - enable_async_comms: bool = False, - schedule: str = "default", + schedule: Schedule = Schedule.MLP_OVERLAP, ) -> list[UBatchContext]: assert num_micro_batches == 2, "only been tested with 2 micro-batches" """ @@ -206,7 +270,6 @@ def make_ubatch_contexts( num_micro_batches], gpu_comm_done_event=gpu_comm_done_events[i], gpu_compute_done_event=gpu_compute_done_events[i], - enable_async_comms=enable_async_comms, schedule=schedule) ctxs.append(ctx) From 355acf14b45fe941d6020fcd17939a3e227c02fd Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 2 Sep 2025 05:34:36 +0000 Subject: [PATCH 6/8] alt schedule Signed-off-by: Lucas Wilkinson --- vllm/compilation/ubatch_wrapper.py | 18 +++- vllm/config/parallel.py | 7 ++ vllm/engine/arg_utils.py | 6 ++ .../layers/fused_moe/modular_kernel.py | 9 +- vllm/model_executor/models/deepseek_v2.py | 101 +++++++++++------- vllm/v1/attention/backends/mla/common.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 11 +- vllm/v1/worker/ubatching.py | 74 ++++++++----- 8 files changed, 146 insertions(+), 83 deletions(-) diff --git a/vllm/compilation/ubatch_wrapper.py b/vllm/compilation/ubatch_wrapper.py index 5bc3df813395..d4802b1404f7 100644 --- a/vllm/compilation/ubatch_wrapper.py +++ b/vllm/compilation/ubatch_wrapper.py @@ -42,14 +42,15 @@ class CUDAGraphMetaData: class UBatchWrapper: def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, device: torch.cuda.device): + runtime_mode: CUDAGraphMode, device: torch.cuda.device, + delayed_start: bool = False): self.runnable = runnable self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.comm_stream = torch.cuda.Stream() self.device = device self.ready_barrier = threading.Barrier(3) - + self.delayed_start = delayed_start self.cudagraphs = {} self.cudagraph_wrapper = None @@ -75,7 +76,6 @@ def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor: @torch.inference_mode() def _capture_ubatch_thread(results, ubatch_metadata): - # print(f"Starting Request on ubatch: 
{ubatch_ctx.id}", flush=True) context = ubatch_metadata.context with torch.cuda.stream(context.compute_stream): _ = torch.cuda.current_blas_handle() @@ -170,7 +170,8 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, positions, inputs_embeds, intermediate_tensors, compute_stream, num_tokens_across_dp, batch_descriptor, - cudagraph_runtime_mode) -> list[UbatchMetadata]: + cudagraph_runtime_mode, + delayed_start: bool = False) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] @@ -186,6 +187,12 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode)) + # Map CLI/config schedule string to Schedule enum + schedule_str = self.vllm_config.parallel_config.microbatch_schedule + schedule = Schedule.MLP_OVERLAP + if schedule_str == Schedule.MLA_ATTN_OVERLAP.value: + schedule = Schedule.MLA_ATTN_OVERLAP + ubatch_ctxs = make_ubatch_contexts( num_micro_batches=len(ubatch_slices), comm_stream=self.comm_stream, @@ -193,7 +200,8 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, forward_contexts=forward_contexts, ready_barrier=self.ready_barrier, device=self.device, - schedule=Schedule.MLP_OVERLAP, + schedule=schedule, + delayed_start=delayed_start, ) ubatch_metadata: list[UbatchMetadata] = [] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index a0f250f24aab..9397f918b7e7 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -135,6 +135,13 @@ class ParallelConfig: request is greater than this threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" + microbatch_schedule: Literal["mlp_overlap", "mla_attn_overlap"] = "mlp_overlap" + """Schedule policy for microbatch overlap coordination. 
+ + - "mlp_overlap": overlap MLP compute and communication across ubatches + - "mla_attn_overlap": overlap MLA attention and communication across ubatches + """ + enable_async_comms: bool = False """enable async comms""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 98120f71cf1f..8245c20bff0c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -317,6 +317,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_microbatching: bool = ParallelConfig.enable_microbatching microbatching_token_threshold: int = ParallelConfig.microbatching_token_threshold + microbatch_schedule: str = ParallelConfig.microbatch_schedule enable_async_comms: bool = ParallelConfig.enable_async_comms eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb @@ -680,6 +681,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["enable_microbatching"]) parallel_group.add_argument("--microbatching-token-threshold", **parallel_kwargs["microbatching_token_threshold"]) + parallel_group.add_argument( + "--microbatch-schedule", + dest="microbatch_schedule", + **parallel_kwargs["microbatch_schedule"]) parallel_group.add_argument("--enable-async-comms", **parallel_kwargs["enable_async_comms"]) parallel_group.add_argument("--enable-eplb", @@ -1303,6 +1308,7 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, enable_microbatching=self.enable_microbatching, microbatching_token_threshold=self.microbatching_token_threshold, + microbatch_schedule=self.microbatch_schedule, enable_async_comms=self.enable_async_comms, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 25db394eec17..4b9d01df82ec 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -847,7 +847,7 @@ def forward( else: # Overlap shared expert compute with all2all dispatch. 
dbo_maybe_run_recv_hook() - hook, receiver = self.prepare_finalize.prepare_async( + _hook, receiver = self.prepare_finalize.prepare_async( a1, a1_scale, a2_scale, @@ -858,20 +858,15 @@ def forward( apply_router_weight_on_input, self.fused_experts.quant_config, ) - if dbo_register_recv_hook(hook, schedules=(Schedule.MLP_OVERLAP, )): hook = lambda: None + dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) - dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) - - # assert self.shared_experts is not None if self.shared_experts is not None: - assert False shared_output = self.shared_experts(a1) dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP, )) - hook() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 588663f44ce6..9162dd695564 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -49,6 +49,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -146,61 +147,83 @@ def __init__( self.n_local_physical_experts) self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: + + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + self.shared_experts = None + else: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + 
num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. 
/ self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + if self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( @@ -725,6 +748,8 @@ def forward( class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + delayed_dbo_start = True + packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b53247f57f84..dba6a78c924e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -218,6 +218,7 @@ infer_global_hyperparameters, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.ubatching import Schedule, dbo_yield try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -1304,7 +1305,7 @@ def forward( layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) - #dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP,)) + dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP,)) if has_prefill: output[num_decode_tokens:] = self._forward_prefill( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b669020db264..0093692a8ddf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -84,7 +84,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts from .utils import (AttentionGroup, MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, @@ -2348,10 +2347,16 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.vllm_config, runtime_mode=CUDAGraphMode.FULL) elif self.parallel_config.enable_microbatching: + delayed_start = getattr(self.model, "delayed_dbo_start", False) + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper(self.model, self.vllm_config, CUDAGraphMode.FULL, self.device) + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.FULL, self.device, + delayed_start=delayed_start) else: - self.model = UBatchWrapper(self.model, self.vllm_config, CUDAGraphMode.NONE, self.device) + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.NONE, self.device, + delayed_start=delayed_start) def reload_weights(self) -> None: diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index c7d557c355b2..2a57751fcbaa 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -4,7 +4,8 @@ from contextlib import contextmanager from enum import Enum from typing import Optional - +from torch.library import Library +from functools import lru_cache import torch from vllm import forward_context @@ -18,6 +19,10 @@ class Schedule(Enum): MLP_OVERLAP = "mlp_overlap" MLA_ATTN_OVERLAP = "mla_attn_overlap" + +_SCHEDULE_WAIT_STAGES = { + Schedule.MLA_ATTN_OVERLAP: 2, +} class UBatchContext: @@ -35,6 +40,7 @@ def __init__(self, cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event, + started: bool = True, schedule: Schedule = Schedule.MLP_OVERLAP): self.id = id self.comm_stream = comm_stream @@ -47,6 +53,7 @@ def __init__(self, self.gpu_comm_done_event = gpu_comm_done_event self.gpu_compute_done_event = 
gpu_compute_done_event self.schedule = schedule + self.started = started self.recv_hook = None def __enter__(self): @@ -55,8 +62,13 @@ def __enter__(self): _CURRENT_CONTEXTS[self.id] = self self.ready_barrier.wait() + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) if self.id > 0 else 1 + for _ in range(wait_stages - 1): + self.yield_() + self.cpu_wait_event.wait() self.cpu_wait_event.clear() + self._restore_context() # Assume we start on the compute stream assert current_stream() == self.compute_stream @@ -64,10 +76,20 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT - _CURRENT_CONTEXTS[self.id] = None - del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) + if self.id == 0: + # Keep advance in the next micro-batch + for _ in range(wait_stages - 1): + self.yield_() + self.maybe_run_recv_hook() + self.cpu_signal_event.set() self.cpu_wait_event.clear() + + del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + _CURRENT_CONTEXTS[self.id] = None + self.current_stream = self.compute_stream torch.cuda.set_stream(self.current_stream) return False @@ -125,7 +147,6 @@ def switch_to_compute_sync(self): def maybe_run_recv_hook(self): if self.recv_hook is not None: - print("run recv hook", self.id, self.recv_hook) self.recv_hook() self.recv_hook = None @@ -162,6 +183,13 @@ def dbo_current_ubatch_id() -> int: return _THREAD_ID_TO_CONTEXT[threading.get_ident()] +def dbo_start(): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + ctx.started = True + + def _register_ubatch_function(func, all_schedules_default: bool = False): def wrapper(schedules: tuple[Schedule] = (), @@ -169,6 +197,8 @@ def wrapper(schedules: tuple[Schedule] = (), if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] ctx = _CURRENT_CONTEXTS[ctx_idx] + if ctx.started: + return if all_schedules or ctx.schedule in schedules: func(ctx) @@ -191,34 +221,19 @@ def wrapper(schedules: tuple[Schedule] = (), UBatchContext.switch_to_compute_sync) -@contextmanager -def dbo_switch_to_compute_sync(): - if dbo_enabled(): - dbo_switch_to_compute_sync() - yield - dbo_switch_to_comm_sync() - else: - yield - -@contextmanager -def dbo_switch_to_comm_sync(): - if dbo_enabled(): - dbo_switch_to_comm_sync() - yield - dbo_switch_to_compute_sync() - else: - yield +lib = Library("vllm_dbo", "DEF") +lib.define("start(Tensor! 
x) -> ()") # in-place, returns x +@torch.library.impl("vllm_dbo::start", "CompositeImplicitAutograd") +def _dbo_start_impl(x: torch.Tensor): + dbo_start() + return None -@contextmanager -def dbo_run_on_comm_async(): - if dbo_enabled(): - dbo_run_on_comm_async() - yield - else: - yield +@lru_cache(maxsize=1) +def dbo_debug_annotate(): + return True def dbo_register_recv_hook(recv_hook, schedules: tuple[Schedule] = (), @@ -227,7 +242,6 @@ def dbo_register_recv_hook(recv_hook, ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] if all_schedules or _CURRENT_CONTEXTS[ctx_idx].schedule in schedules: next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] - print("register recv hook", next_ctx.id, recv_hook) next_ctx.recv_hook = recv_hook return True return False @@ -241,6 +255,7 @@ def make_ubatch_contexts( ready_barrier: threading.Barrier, device: Optional[torch.device] = None, schedule: Schedule = Schedule.MLP_OVERLAP, + delayed_start: bool = False, ) -> list[UBatchContext]: assert num_micro_batches == 2, "only been tested with 2 micro-batches" """ @@ -270,6 +285,7 @@ def make_ubatch_contexts( num_micro_batches], gpu_comm_done_event=gpu_comm_done_events[i], gpu_compute_done_event=gpu_compute_done_events[i], + started=not delayed_start, schedule=schedule) ctxs.append(ctx) From 546206306bbb82d17154c1c8232b7b51e68705be Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 2 Sep 2025 19:59:01 +0000 Subject: [PATCH 7/8] refactor prepare finalize Signed-off-by: Lucas Wilkinson --- .../fused_moe/deepep_ht_prepare_finalize.py | 190 ++++++-------- .../fused_moe/deepep_ll_prepare_finalize.py | 118 ++++++--- .../layers/fused_moe/modular_kernel.py | 248 ++++++++++++------ .../layers/fused_moe/pplx_prepare_finalize.py | 145 ++++++---- 4 files changed, 415 insertions(+), 286 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 89bd322c6f80..cbe164cc821d 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -61,32 +61,63 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: return None return deep_ep.Buffer.get_combine_config(self.dp_size) - def _do_dispatch( + def _create_prepare_ops( self, - tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, - num_experts: int, + a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> Callable: + ) -> mk.PrepareResultType: + + # Apply router weights on input if requested (only supports topk=1) + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + # Quantize prior to dispatch for block-quantized path, otherwise defer + if quant_config.is_block_quantized: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + a1_post_scale = None + else: + a1q = a1 + a1q_scale = None + a1_post_scale = 
a1_scale - has_scales = token_scales is not None + # Inline dispatch (sync send+recv) + has_scales = a1q_scale is not None - dbo_yield_and_switch_from_compute_to_comm() (num_tokens_per_rank, num_tokens_per_rdma_rank, dispatch_expert_num_tokens, is_token_in_rank, event) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, + topk_idx=topk_ids, num_experts=num_experts, previous_event=None, async_finish=False, allocate_on_comm_stream=False) - token_data = tokens + token_data: Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]] + token_data = a1q if has_scales: - token_data = (tokens, token_scales) + token_data = (a1q, a1q_scale) + + ######################################################################## + yield # Pre-dispatch done + ######################################################################## ( token_data, expert_topk_ids, expert_topk_weights, @@ -98,14 +129,12 @@ def _do_dispatch( num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, num_tokens_per_expert=dispatch_expert_num_tokens, - topk_idx=rank_topk_ids, - topk_weights=rank_topk_weights, - # expert_alignment rounds the number of tokens per expert - # to this value. + topk_idx=topk_ids, + topk_weights=topk_weights, expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=self.async_prepare, + async_finish=False, allocate_on_comm_stream=False) dbo_yield_and_switch_from_comm_to_compute() @@ -113,36 +142,12 @@ def _do_dispatch( a2a_idx = dbo_current_ubatch_id() self.handles[a2a_idx] = handle - return lambda: self._receiver( - event, - has_scales, - token_data, - expert_topk_ids, - num_experts, - expert_num_tokens_per_expert_list, - expert_topk_weights, - a1_scale, - quant_config, - ) - - def _receiver( - self, - event: deep_ep.EventOverlap, - has_scales: bool, - token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], - expert_topk_ids: Optional[torch.Tensor], - num_experts: int, - expert_num_tokens_per_expert_list: list[int], - expert_topk_weights: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - if self.async_prepare: - event.current_stream_wait() - + # Unpack token data if has_scales: + assert isinstance(token_data, tuple) expert_x, expert_x_scale = token_data else: + assert isinstance(token_data, torch.Tensor) expert_x, expert_x_scale = token_data, None # The existing MOE kernels assume that all entries of topk_ids are @@ -183,58 +188,14 @@ def _receiver( per_act_token_quant=False, block_shape=quant_config.block_shape) + ######################################################################## + yield # Dispatch send+recv done (sync) + ######################################################################## + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def supports_async(self) -> bool: - return True - - def prepare_async( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> Callable: - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * topk_weights.to(a1.dtype) - - if 
quant_config.is_block_quantized: - # Quant and Dispatch - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - if a1q_scale is not None and a1q_scale.numel() == 1: - a1q_scale = a1q_scale.view(1, 1) - a1_post_scale = None - else: - a1q = a1 - a1q_scale = None - a1_post_scale = a1_scale - - return self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config) - - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -245,14 +206,14 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) - return receiver() - - def finalize( + ) -> mk.SyncPrepareOps: + return mk.SyncPrepareOps.from_generator( + self._create_prepare_ops(a1, a1_scale, a2_scale, topk_weights, + topk_ids, num_experts, expert_map, + apply_router_weight_on_input, + quant_config)) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -279,7 +240,10 @@ def finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) - dbo_yield_and_switch_from_compute_to_comm() + ######################################################################## + yield # Pre-combine done + ######################################################################## + combined_x, _, event = self.buffer.combine( x=fused_expert_output, handle=handle, @@ -288,6 +252,26 @@ def finalize( previous_event=None, async_finish=False, allocate_on_comm_stream=False) - dbo_yield_and_switch_from_comm_to_compute() # Respect inplace outputs. 
output.copy_(combined_x, non_blocking=True) + + ######################################################################## + yield # Combine send-recv done + ######################################################################## + + return None + + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.SyncFinalizeOps: + return mk.SyncFinalizeOps.from_generator( + self._create_finalize_ops(output, fused_expert_output, + topk_weights, topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl)) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 4b4c0cf4e922..9cdf9725ce46 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional, Union, Callable import deep_ep import torch @@ -11,9 +11,7 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) -from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook, dbo_yield) +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -115,7 +113,7 @@ def _do_quant( def supports_async(self) -> bool: return True - def prepare_async( + def _create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -126,10 +124,9 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.ReceiverType: - - hidden_size = a1.size(1) + ) -> mk.PrepareResultType: a2a_idx = dbo_current_ubatch_id() + hidden_size = a1.size(1) if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ @@ -147,6 +144,10 @@ def prepare_async( assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) + + ######################################################################## + yield # Pre-dispatch done + ######################################################################## # Dispatch _expert_x, expert_num_tokens, handle, _, recv_hook= \ @@ -158,21 +159,28 @@ def prepare_async( async_finish=False, return_recv_hook=True) self.handles[a2a_idx] = handle + + ######################################################################## + yield # Dispatch send done + ######################################################################## + + recv_hook() + + ######################################################################## + yield # Dispatch recv done + ######################################################################## - def _post_recv() -> mk.PrepareResultType: - expert_x, expert_x_scale = self._do_quant( - _expert_x, a1_scale, a1.dtype, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) - - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, - expert_num_tokens_cpu=None) + expert_x, expert_x_scale = self._do_quant( + _expert_x, 
a1_scale, a1.dtype, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) - return expert_x, expert_x_scale, expert_tokens_meta, None, None + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) - return (recv_hook, _post_recv) + return expert_x, expert_x_scale, expert_tokens_meta, None, None - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -183,22 +191,20 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - recv_hook, post_recv = self.prepare_async( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - num_experts, - expert_map, - apply_router_weight_on_input, - quant_config, - ) - recv_hook() - return post_recv() - - def finalize( + ) -> mk.AsyncPrepareOps: + return mk.AsyncPrepareOps.from_generator( + self._create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config)) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -212,7 +218,6 @@ def finalize( ), ("Weight application and reduction happens in the combine kernel.") a2a_idx = dbo_current_ubatch_id() - do_recv_hook = dbo_enabled() handle = self.handles[a2a_idx] assert handle is not None @@ -220,9 +225,11 @@ def finalize( if apply_router_weight_on_input: # weights have already been applied. combine_topk_weights = torch.ones_like(topk_weights) + + ######################################################################## + yield # Pre-combine done + ######################################################################## - # TODO (varun) : Enable zero copy mode - dbo_maybe_run_recv_hook() _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, @@ -230,8 +237,35 @@ def finalize( handle, async_finish=False, zero_copy=False, - return_recv_hook=do_recv_hook, + return_recv_hook=True, out=output) - if recv_hook is not None: - dbo_register_recv_hook(recv_hook, all_schedules=True) - dbo_yield(all_schedules=True) + + ######################################################################## + yield # Combine send done + ######################################################################## + + recv_hook() + + ######################################################################## + yield # Combine recv done + ######################################################################## + + return None + + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.AsyncFinalizeOps: + return mk.AsyncFinalizeOps.from_generator( + self._create_finalize_ops(output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl)) + \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4b9d01df82ec..4955079e90dd 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Callable, Optional, Union, final +from typing import 
Callable, Optional, Union, final, Generator import torch @@ -166,6 +166,119 @@ def apply(self, output: Optional[torch.Tensor], ReceiverType = Callable[[], PrepareResultType] +# +# Prepare and Finalize Op Chains +# +# The prepare and finalize functions are broken down into a chain of sequential +# operations/steps. This + +class _PhasedGen[R]: + """ + Enforce an exact number of yields (phases), then a final return. + + Contract: + - The generator must yield exactly `expected_yields` times. + - The next advance must StopIteration with a return value (may be None). + - Early StopIteration or extra yields raise RuntimeError. + - Duplicate step/finish after completion raises RuntimeError. + """ + __slots__ = ("_gen", "_expected", "_steps", "_done", "_ret") + + def __init__(self, gen: Generator[None, None, R], expected_yields: int): + self._gen = gen + self._expected = expected_yields + self._steps = 0 + self._done = False + self._ret: Optional[R] = None + + def step(self, label: str) -> None: + if self._done: + raise RuntimeError(f"Generator already finished; unexpected '{label}'.") + if self._steps >= self._expected: + raise RuntimeError( + f"Too many steps: called '{label}' after {self._expected} phases; " + "expected to finish instead." + ) + try: + next(self._gen) + except StopIteration: + raise RuntimeError( + f"Generator ended early during '{label}' " + f"(completed {self._steps}/{self._expected} phases)." + ) + self._steps += 1 + + def finish(self, label: str) -> R: + if self._done: + raise RuntimeError(f"Generator already finished; duplicate '{label}'.") + if self._steps != self._expected: + raise RuntimeError( + f"Cannot finish at '{label}': only {self._steps}/" + f"{self._expected} phases completed." + ) + try: + next(self._gen) + except StopIteration as e: + self._done = True + self._ret = e.value # may be None + return self._ret # type: ignore[return-value] + else: + raise RuntimeError( + f"Generator yielded more than expected ({self._expected}); " + f"should have finished at '{label}'." + ) + +@dataclass +class AsyncOps[R]: + """ + 3-phase async: + 1) prepare() + 2) send() + 3) recv() + 4) finish() -> R + """ + prepare: Callable[[], None] + send: Callable[[], None] + recv: Callable[[], None] + finish: Callable[[], R] + + @classmethod + def from_generator(cls, gen: Generator[None, None, R]) -> 'AsyncOps[R]': + ph = _PhasedGen[R](gen, expected_yields=3) + return cls( + prepare=lambda: ph.step("prepare"), + send=lambda: ph.step("send"), + recv=lambda: ph.step("recv"), + finish=lambda: ph.finish("finish"), + ) + + +@dataclass +class SyncOps[R]: + """ + 2-phase sync: + 1) prepare() + 2) send_recv() + 3) finish() -> R + """ + prepare: Callable[[], None] + send_recv: Callable[[], None] + finish: Callable[[], R] + + @classmethod + def from_generator(cls, gen: Generator[None, None, R]) -> 'SyncOps[R]': + ph = _PhasedGen[R](gen, expected_yields=2) + return cls( + prepare=lambda: ph.step("prepare"), + send_recv=lambda: ph.step("send_recv"), + finish=lambda: ph.finish("finish"), + ) + +AsyncPrepareOps = AsyncOps[PrepareResultType] +SyncPrepareOps = SyncOps[PrepareResultType] +AsyncFinalizeOps = AsyncOps[None] +SyncFinalizeOps = SyncOps[None] + # TODO: pass FusedMoEParallelConfig in as ctor parameter? 
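For readers unfamiliar with the generator idiom that `_PhasedGen`, `SyncOps`, and `AsyncOps` rely on (yield exactly once per phase, then return the result, which the driver harvests from `StopIteration.value`), here is a minimal standalone sketch of the same contract; the function and driver names are illustrative only and do not appear in this patch.

from typing import Generator, Tuple

def sync_prepare_ops() -> Generator[None, None, Tuple[str, str]]:
    # Phase 1: pre-dispatch work (e.g. quantize the activations).
    quantized = "a1q"
    yield  # prepare() done
    # Phase 2: blocking send + recv (e.g. the all2all dispatch).
    dispatched = quantized + "@experts"
    yield  # send_recv() done
    # The final result is delivered to the caller via StopIteration.
    return dispatched, "expert_tokens_meta"

def drive_sync(gen: Generator[None, None, Tuple[str, str]]) -> Tuple[str, str]:
    next(gen)  # prepare()
    next(gen)  # send_recv()
    try:
        next(gen)
    except StopIteration as e:
        return e.value  # finish(): the generator's return value
    raise RuntimeError("generator yielded more phases than expected")

if __name__ == "__main__":
    print(drive_sync(sync_prepare_ops()))  # -> ('a1q@experts', 'expert_tokens_meta')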
class FusedMoEPrepareAndFinalize(ABC): """ @@ -174,7 +287,7 @@ class FusedMoEPrepareAndFinalize(ABC): """ @abstractmethod - def prepare( + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -185,7 +298,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> PrepareResultType: + ) -> Union[SyncPrepareOps, AsyncPrepareOps]: """ Perform any quantization (and/or) dispatching needed for this kernel. - a1: The (unquantized) input to the MoE layer. @@ -211,53 +324,8 @@ def prepare( """ raise NotImplementedError - def supports_async(self) -> bool: - """ - Indicates whether or not this class implements prepare_async. - """ - return False - - def prepare_async( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> ReceiverType: - """ - Perform any quantization (and/or) dispatching needed for this kernel - but do not wait for results from other workers. - - a1: The (unquantized) input to the MoE layer. - - a1_scale: Optional scales for a1 - - a2_scale: Optional scales for the second MoE gemm. Required to make - sure the quantization is consistent for both gemms. - - topk_ids: The topk ids. - - topk_weights: The topk weights. - - num_experts: The total number of experts in the global expert space. - - expert_map: A tensor mapping expert indices from the global expert - space to the local expert space of the expert parallel shard. - - apply_router_weight_on_input: When True, apply the weights to the - activations, before quantization + dispatching. - - Returns a callback that when invoked waits for results from other - workers and has the same return signature as `prepare`, e.g. - - receiver = obj.prepare_async(...) - a, a_scales, expert_meta, topk_ids, topk_weights = receiver() - - is equivalent to: - - a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) - """ - raise NotImplementedError - @abstractmethod - def finalize( + def create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -265,7 +333,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: TopKWeightAndReduce, - ) -> None: + ) -> Union[SyncFinalizeOps, AsyncFinalizeOps]: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -824,53 +892,48 @@ def forward( global_num_experts = local_num_experts shared_output: torch.Tensor + + prepare_ops = self.prepare_finalize.create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) - if not self.prepare_finalize.supports_async(): - assert False + prepare_ops.prepare() + if isinstance(prepare_ops, SyncOps): # Run shared experts serially with dispatch. 
if self.shared_experts is not None: shared_output = self.shared_experts(a1) - - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + prepare_ops.send_recv() else: + assert isinstance(prepare_ops, AsyncOps) + # Overlap shared expert compute with all2all dispatch. dbo_maybe_run_recv_hook() - _hook, receiver = self.prepare_finalize.prepare_async( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) - if dbo_register_recv_hook(hook, - schedules=(Schedule.MLP_OVERLAP, )): - hook = lambda: None - dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) + prepare_ops.send() + + recv_done = dbo_register_recv_hook( + lambda: prepare_ops.recv(), + schedules=(Schedule.MLP_OVERLAP, )) + dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) if self.shared_experts is not None: shared_output = self.shared_experts(a1) dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP, )) - hook() - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = receiver() + if not recv_done: + prepare_ops.recv() + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = prepare_ops.finish() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -909,7 +972,7 @@ def forward( apply_router_weight_on_input=apply_router_weight_on_input, ) - self.prepare_finalize.finalize( + finalize_ops = self.prepare_finalize.create_finalize_ops( output, fused_out, topk_weights, @@ -917,6 +980,23 @@ def forward( apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), ) + + if isinstance(finalize_ops, SyncOps): + finalize_ops.prepare() + finalize_ops.send_recv() + finalize_ops.finish() + else: + assert isinstance(finalize_ops, AsyncOps) + finalize_ops.prepare() + dbo_maybe_run_recv_hook() + finalize_ops.send() + + if dbo_register_recv_hook( + lambda: finalize_ops.recv(), all_schedules=True): + dbo_yield(all_schedules=True) + else: + finalize_ops.recv() + finalize_ops.finish() if self.shared_experts is None: return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 756fc10d26e9..7b5ece9c1f40 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -13,9 +13,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up -from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook, dbo_yield) +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -95,7 +93,7 @@ def num_dispatchers(self) -> int: def supports_async(self) -> bool: return True - def prepare_async( + def _create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -206,6 +204,10 @@ def prepare_async( # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None + ######################################################################## + yield # Pre-dispatch done + 
######################################################################## + self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -218,34 +220,37 @@ def prepare_async( do_recv=False, ) - def _recv_hook(self): - self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=topk_ids, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) + ######################################################################## + yield # Dispatch send done + ######################################################################## - def _post_recv() -> mk.PrepareResultType: - if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, : - orig_a_scale_block_shape] - assert expert_x_scale.ndim == 3 + self.a2as[a2a_idx].dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, - expert_num_tokens_cpu=None) + ######################################################################## + yield # Dispatch recv done + ######################################################################## - return expert_x, expert_x_scale, expert_tokens_meta, None, None + if expert_x_scale is not None: + expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] + assert expert_x_scale.ndim == 3 - return _recv_hook, _post_recv + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, + expert_num_tokens_cpu=None) - def prepare( + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def create_prepare_ops( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -256,22 +261,21 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - recv_hook, post_recv = self.prepare_async( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - num_experts, - expert_map, - apply_router_weight_on_input, - quant_config, - ) - recv_hook() - return post_recv() - - def finalize( + ) -> mk.AsyncPrepareOps: + return mk.AsyncPrepareOps.from_generator( + self._create_prepare_ops( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + )) + + def _create_finalize_ops( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -303,7 +307,24 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - dbo_maybe_run_recv_hook() + ######################################################################## + yield # Pre-combine done + ######################################################################## + + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids.view(dtype=torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + ######################################################################## + yield # Combine send done (no-op for pplx combine) + ######################################################################## + self.a2as[a2a_idx].combine( out_tokens=output, indices=topk_ids.view(dtype=torch.uint32), @@ -314,17 +335,27 @@ def finalize( do_recv=True, ) - def recv_hook(): - 
self.a2as[a2a_idx].combine( - out_tokens=output, - indices=topk_ids.view(dtype=torch.uint32), - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) + ######################################################################## + yield # Combine recv done + ######################################################################## + + return None - if recv_hook is not None: - dbo_register_recv_hook(recv_hook, all_schedules=True) - dbo_yield(all_schedules=True) + def create_finalize_ops( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> mk.AsyncFinalizeOps: + return mk.AsyncFinalizeOps.from_generator( + self._create_finalize_ops( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + )) From f0c0a5c9b6aba8773114d4555814b7056df2aee7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 3 Sep 2025 05:09:24 +0000 Subject: [PATCH 8/8] alt schedules Signed-off-by: Lucas Wilkinson --- vllm/compilation/ubatch_wrapper.py | 4 +- vllm/config/parallel.py | 4 +- .../layers/fused_moe/modular_kernel.py | 18 +++-- vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/worker/ubatching.py | 68 ++++++++++++++----- 5 files changed, 71 insertions(+), 25 deletions(-) diff --git a/vllm/compilation/ubatch_wrapper.py b/vllm/compilation/ubatch_wrapper.py index d4802b1404f7..ea3adedf84fa 100644 --- a/vllm/compilation/ubatch_wrapper.py +++ b/vllm/compilation/ubatch_wrapper.py @@ -190,8 +190,8 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, # Map CLI/config schedule string to Schedule enum schedule_str = self.vllm_config.parallel_config.microbatch_schedule schedule = Schedule.MLP_OVERLAP - if schedule_str == Schedule.MLA_ATTN_OVERLAP.value: - schedule = Schedule.MLA_ATTN_OVERLAP + if schedule_str == Schedule.ATTN_SHARED_OVERLAP.value: + schedule = Schedule.ATTN_SHARED_OVERLAP ubatch_ctxs = make_ubatch_contexts( num_micro_batches=len(ubatch_slices), diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9397f918b7e7..f381729e1a07 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -135,11 +135,11 @@ class ParallelConfig: request is greater than this threshold, microbatching will be used. Otherwise, the request will be processed in a single batch.""" - microbatch_schedule: Literal["mlp_overlap", "mla_attn_overlap"] = "mlp_overlap" + microbatch_schedule: Literal["mlp_overlap", "ATTN_SHARED_OVERLAP"] = "mlp_overlap" """Schedule policy for microbatch overlap coordination. 
- "mlp_overlap": overlap MLP compute and communication across ubatches - - "mla_attn_overlap": overlap MLA attention and communication across ubatches + - "ATTN_SHARED_OVERLAP": overlap MLA attention and communication across ubatches """ enable_async_comms: bool = False diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4955079e90dd..4b58f81bc943 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -14,7 +14,8 @@ _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv from vllm.v1.worker.ubatching import (Schedule, dbo_maybe_run_recv_hook, - dbo_register_recv_hook, dbo_yield) + dbo_register_recv_hook, dbo_yield, + dbo_current_schedule) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -921,13 +922,17 @@ def forward( recv_done = dbo_register_recv_hook( lambda: prepare_ops.recv(), - schedules=(Schedule.MLP_OVERLAP, )) + schedules=(Schedule.MLP_OVERLAP, Schedule.MLP_SHARED_OVERLAP)) dbo_yield(schedules=(Schedule.MLP_OVERLAP, )) - if self.shared_experts is not None: + # If we are using the MLP_SHARED_OVERLAP schedule, we overlap with + # the combine instead of the dispatch. + # TODO(lucas): refactor this scheduling logic + if self.shared_experts is not None \ + and dbo_current_schedule() != Schedule.MLP_SHARED_OVERLAP: shared_output = self.shared_experts(a1) - dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP, )) + dbo_yield(schedules=(Schedule.ATTN_SHARED_OVERLAP, Schedule.MLP_SHARED_OVERLAP)) if not recv_done: prepare_ops.recv() @@ -991,6 +996,11 @@ def forward( dbo_maybe_run_recv_hook() finalize_ops.send() + # If we didn't overlap with the dispatch overlap with the combine + # TODO(lucas): refactor this scheduling logic + if self.shared_experts is not None and shared_output is None: + shared_output = self.shared_experts(a1) + if dbo_register_recv_hook( lambda: finalize_ops.recv(), all_schedules=True): dbo_yield(all_schedules=True) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index dba6a78c924e..3ff2786f7852 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1305,7 +1305,7 @@ def forward( layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) - dbo_yield(schedules=(Schedule.MLA_ATTN_OVERLAP,)) + dbo_yield(schedules=(Schedule.ATTN_SHARED_OVERLAP,)) if has_prefill: output[num_decode_tokens:] = self._forward_prefill( diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 2a57751fcbaa..85109f6f8244 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -17,11 +17,33 @@ class Schedule(Enum): + # Schedule notation legend: + # S = Shared expert + # A0 = MLA qkv pro, + # A1 = Core attn + out proj + MoE gate + # D = Dispatch + # C = Combine + + # Comp: |-A0₀-A1₀-S₀-||-MLP₁-||-MLP₀-||-A0₁-A1₁-S₁-| + # Comm: |-----D₁-----||--D₀--||--C₁--||-----C₀-----| + # Order: D₁ send, A0₀, A1₀, S₀, D₁ recv, D₀ send, MLP₁, D₀ recv, + # C₁ send, MLP₀, C₁ recv, C₀ send, A0₁, A1₁, S₁, C₀ recv. MLP_OVERLAP = "mlp_overlap" - MLA_ATTN_OVERLAP = "mla_attn_overlap" + + # Comp: |-A0₀-A1₀-||-MLP₁-||-S₁-MLP₀-||-S₀-A0₁-A1₁-| + # Comm: |----D₁---||--D₀--||----C₁---||-----C₀-----| + # Order: D₁ send, A0₀, A1₀, D₁ recv, D₀ send, MLP₁, D₀ recv, + # C₁ send, S₁, MLP₀, C₁ recv, C₀ send, S₀, A0₁, A1₁, C₀ recv. 
+ MLP_SHARED_OVERLAP = "mlp_shared_overlap" + + # Comp: |-S₀-A0₁-|-MLP₀-|-A1₁-||-S₁-A0₀-|-MLP₁-|-A1₀-| + # Comm: |---D₀---| |-C₀--||---D₁---| |-C₁--| + # Order: D₀ send, S₀, A0₁, D₀ recv, MLP₀, C₀ send, A1₁, C₀ recv, + # D₁ send, S₁, A0₀, D₁ recv, MLP₁, C₁ send, A1₀, C₁ recv. + ATTN_SHARED_OVERLAP = "attn_shared_overlap" -_SCHEDULE_WAIT_STAGES = { - Schedule.MLA_ATTN_OVERLAP: 2, +_SCHEDULE_WAIT_STAGES = { # Default is 1 + Schedule.ATTN_SHARED_OVERLAP: 2, } @@ -62,13 +84,14 @@ def __enter__(self): _CURRENT_CONTEXTS[self.id] = self self.ready_barrier.wait() - wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) if self.id > 0 else 1 - for _ in range(wait_stages - 1): - self.yield_() - self.cpu_wait_event.wait() self.cpu_wait_event.clear() + if self.id > 0: + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) + for _ in range(wait_stages - 1): + self._cpu_yield(check_context=False) + self._restore_context() # Assume we start on the compute stream assert current_stream() == self.compute_stream @@ -77,13 +100,15 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT - wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) if self.id == 0: - # Keep advance in the next micro-batch + # Keep advancing the next micro-batch + wait_stages = _SCHEDULE_WAIT_STAGES.get(self.schedule, 1) for _ in range(wait_stages - 1): self.yield_() + # Cleanup and trailing recv hooks self.maybe_run_recv_hook() - + + self.maybe_run_recv_hook() self.cpu_signal_event.set() self.cpu_wait_event.clear() @@ -122,12 +147,17 @@ def stream_string(self): assert self.current_stream == self.comm_stream return "COMM" - def _cpu_yield(self): + def _cpu_yield(self, check_context: bool = True): # It is critical for correctness that only one thread is running # at a time. 
These asserts just make sure that this is the only
         # thread running before waking the other one up and going to sleep
-        assert forward_context._forward_context == self.forward_context
-        assert current_stream() == self.current_stream
+        assert (
+            not check_context or
+            forward_context._forward_context is self.forward_context)
+        assert (
+            not check_context or
+            current_stream() == self.current_stream)
         assert not self.cpu_wait_event.is_set()
 
         self.cpu_signal_event.set()
@@ -153,7 +183,7 @@ def maybe_run_recv_hook(self):
     def yield_(self):
         self.current_stream = current_stream()
         self._cpu_yield()
-        if self.current_stream == current_stream():
+        if self.current_stream != current_stream():
             self.update_stream(self.current_stream)
 
     def yield_and_switch_from_compute_to_comm(self):
@@ -190,6 +220,11 @@ def dbo_start():
             ctx.started = True
 
 
+def dbo_current_schedule() -> Optional[Schedule]:
+    if len(_THREAD_ID_TO_CONTEXT) == 0:
+        return None
+    return _CURRENT_CONTEXTS[dbo_current_ubatch_id()].schedule
+
 def _register_ubatch_function(func, all_schedules_default: bool = False):
 
     def wrapper(schedules: tuple[Schedule] = (),
@@ -197,7 +232,7 @@ def wrapper(schedules: tuple[Schedule] = (),
         if len(_THREAD_ID_TO_CONTEXT) > 0:
             ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
             ctx = _CURRENT_CONTEXTS[ctx_idx]
-            if ctx.started:
+            if not ctx.started:
                 return
             if all_schedules or ctx.schedule in schedules:
                 func(ctx)
@@ -221,7 +256,8 @@ def wrapper(schedules: tuple[Schedule] = (),
                         UBatchContext.switch_to_compute_sync)
 
 
-
+# DBO start needs to be callable from inside the torch compile region so
+# we register it as a custom op.
 lib = Library("vllm_dbo", "DEF")
 lib.define("start(Tensor! x) -> ()")  # in-place, returns x