1 change: 1 addition & 0 deletions docs/design/fused_moe_modular_kernel.md
@@ -57,6 +57,7 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
The `prepare` function is responsible for input activation quantization and the All2All dispatch. If implemented, `prepare_no_receive` is like `prepare` except that it does not wait to receive results from the other workers; instead it returns a "receiver" callback that must be invoked to wait for the worker's final results. Not every `FusedMoEPrepareAndFinalize` class is required to support this method, but when it is available it can be used to interleave work with the initial all-to-all communication, e.g. interleaving shared experts with the fused experts (see the sketch below). The `finalize` function is responsible for invoking the All2All combine. Additionally, the `finalize` function may or may not do the TopK weight application and reduction (please refer to the TopKWeightAndReduce section).


![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks")
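
As a rough illustration of the `prepare_no_receive` receiver pattern described above, the sketch below interleaves shared-expert compute with the in-flight dispatch. The function and argument names other than `prepare_no_receive` and `finalize` are placeholders, and the real vLLM signatures take more parameters; treat this as a mental model, not the actual API.

```python
# Hypothetical sketch: overlap shared-expert work with the All2All dispatch.
# `prepare_and_finalize`, `fused_experts`, `shared_experts` and the simplified
# signatures are illustrative placeholders, not the real vLLM interfaces.
def forward_with_overlap(prepare_and_finalize, fused_experts, shared_experts,
                         hidden_states, topk_weights, topk_ids):
    # Start quantization + All2All dispatch without blocking on the results.
    receiver = prepare_and_finalize.prepare_no_receive(
        hidden_states, topk_weights, topk_ids)

    # Independent work (e.g. shared experts) runs while the dispatch is in flight.
    shared_out = shared_experts(hidden_states)

    # Block until the dispatched tokens for this worker have arrived.
    expert_inputs = receiver()

    # Run the routed experts, then combine the outputs back via finalize.
    routed_out = fused_experts(expert_inputs)
    return prepare_and_finalize.finalize(routed_out) + shared_out
```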

### FusedMoEPermuteExpertsUnpermute
8 changes: 7 additions & 1 deletion examples/offline_inference/data_parallel.py
@@ -90,7 +90,13 @@ def parse_args():
parser.add_argument(
"--enable-microbatching",
action="store_true",
help=("Enable microbatched execution"),
help=("Enable microbatched execution")
)
parser.add_argument(
"--compilation-config",
type=int,
default=0,
help=("Compilation optimization (O) level 0-3."),
)
parser.add_argument(
"--compilation-config",
21 changes: 16 additions & 5 deletions vllm/compilation/ubatch_wrapper.py
@@ -15,7 +15,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts, Schedule

Check failure (GitHub Actions / pre-commit) Ruff E501: vllm/compilation/ubatch_wrapper.py:18:81: Line too long (82 > 80)

logger = init_logger(__name__)

@@ -40,13 +40,15 @@
class UBatchWrapper:

def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, device: torch.cuda.device):
runtime_mode: CUDAGraphMode, device: torch.cuda.device,
delayed_start: bool = False):

self.runnable = runnable
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device)
self.ready_barrier = threading.Barrier(3)

self.delayed_start = delayed_start
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

self.cudagraph_wrapper = None
@@ -185,7 +187,8 @@
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
positions, inputs_embeds, intermediate_tensors,
compute_stream, dp_metadata, batch_descriptor,
cudagraph_runtime_mode) -> list[UbatchMetadata]:
cudagraph_runtime_mode,
delayed_start: bool = False) -> list[UbatchMetadata]:

# Create one forward context per ubatch
forward_contexts = []
@@ -198,14 +201,22 @@
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode))

# Map CLI/config schedule string to Schedule enum
schedule_str = self.vllm_config.parallel_config.microbatch_schedule
schedule = Schedule.MLP_OVERLAP
if schedule_str == Schedule.ATTN_SHARED_OVERLAP.value:
schedule = Schedule.ATTN_SHARED_OVERLAP

ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream,
forward_contexts=forward_contexts,
ready_barrier=self.ready_barrier)
ready_barrier=self.ready_barrier,
delayed_start=delayed_start)


ubatch_metadata: list[UbatchMetadata] = []

Check failure (GitHub Actions / pre-commit) Ruff F841: vllm/compilation/ubatch_wrapper.py:219:13: Local variable `schedule` is assigned to but never used
for i, ubatch_slice in enumerate(ubatch_slices):
sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
sliced_intermediate_tensors = \
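
The schedule-selection hunk above compares the configured string against a single enum value and falls back to `Schedule.MLP_OVERLAP` otherwise. As a standalone point of comparison, here is a minimal sketch of the same string-to-enum lookup; the `Schedule` members and their string values are assumed for illustration and may not match vLLM's actual definitions.

```python
from enum import Enum


class Schedule(Enum):
    # Assumed members/values, mirroring the strings seen in this diff.
    MLP_OVERLAP = "mlp_overlap"
    ATTN_SHARED_OVERLAP = "ATTN_SHARED_OVERLAP"


def resolve_schedule(schedule_str: str) -> Schedule:
    """Map a config string to a Schedule member, defaulting to MLP_OVERLAP."""
    for member in Schedule:
        if schedule_str == member.value:
            return member
    return Schedule.MLP_OVERLAP


assert resolve_schedule("ATTN_SHARED_OVERLAP") is Schedule.ATTN_SHARED_OVERLAP
assert resolve_schedule("anything-else") is Schedule.MLP_OVERLAP
```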
7 changes: 7 additions & 0 deletions vllm/config/parallel.py
@@ -135,6 +135,13 @@
request is greater than this threshold, microbatching will be used.
Otherwise, the request will be processed in a single batch."""

microbatch_schedule: Literal["mlp_overlap", "ATTN_SHARED_OVERLAP"] = "mlp_overlap"
"""Schedule policy for microbatch overlap coordination.

- "mlp_overlap": overlap MLP compute and communication across ubatches
- "ATTN_SHARED_OVERLAP": overlap MLA attention and communication across ubatches
"""

Check failure (GitHub Actions / pre-commit) Ruff E501: vllm/config/parallel.py:143:81: Line too long (84 > 80)

ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
@@ -316,7 +316,8 @@
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_microbatching: bool = ParallelConfig.enable_microbatching
microbatching_token_threshold: int = ParallelConfig.microbatching_token_threshold

Check failure (GitHub Actions / pre-commit) Ruff E501: vllm/engine/arg_utils.py:319:81: Line too long (85 > 80)
microbatch_schedule: str = ParallelConfig.microbatch_schedule
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
num_redundant_experts: int = EPLBConfig.num_redundant_experts
@@ -682,6 +683,12 @@
**parallel_kwargs["enable_microbatching"])
parallel_group.add_argument("--microbatching-token-threshold",
**parallel_kwargs["microbatching_token_threshold"])
parallel_group.add_argument(
"--microbatch-schedule",
dest="microbatch_schedule",
**parallel_kwargs["microbatch_schedule"])
parallel_group.add_argument("--enable-async-comms",
**parallel_kwargs["enable_async_comms"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config",
@@ -1304,6 +1311,7 @@
enable_expert_parallel=self.enable_expert_parallel,
enable_microbatching=self.enable_microbatching,
microbatching_token_threshold=self.microbatching_token_threshold,
microbatch_schedule=self.microbatch_schedule,
enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config,
max_parallel_loading_workers=self.max_parallel_loading_workers,
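
Taken together, the additions above expose the new knob both as `--microbatch-schedule` on the command line and as an `EngineArgs`/`ParallelConfig` field. A hedged programmatic sketch, using only field names visible in this diff (the model id and threshold value are illustrative, and these fields exist only with this PR applied):

```python
from vllm.engine.arg_utils import EngineArgs

# Sketch only: the microbatching fields come from the diff above and are
# forwarded into ParallelConfig in create_engine_config()
# (see the hunk at @@ -1304,6 +1311,7 @@).
engine_args = EngineArgs(
    model="some/model",                       # illustrative model id
    enable_microbatching=True,
    microbatching_token_threshold=256,        # illustrative threshold
    microbatch_schedule="ATTN_SHARED_OVERLAP",
)
print(engine_args.microbatch_schedule)
```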
158 changes: 86 additions & 72 deletions vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -58,31 +58,63 @@
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)

def _do_dispatch(
def _create_prepare_ops(
self,
tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> Callable:
) -> mk.PrepareResultType:

# Apply router weights on input if requested (only supports topk=1)
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * topk_weights.to(a1.dtype)

# Quantize prior to dispatch for block-quantized path, otherwise defer
if quant_config.is_block_quantized:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
a1_post_scale = None
else:
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale

Check failure (GitHub Actions / pre-commit) Ruff F841: vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py:96:13: Local variable `a1_post_scale` is assigned to but never used

has_scales = token_scales is not None
# Inline dispatch (sync send+recv)
has_scales = a1q_scale is not None

(num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
topk_idx=topk_ids,
num_experts=num_experts,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)

token_data = tokens
token_data: Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]

Check failure (GitHub Actions / pre-commit) Ruff E501: vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py:110:81: Line too long (84 > 80)
token_data = a1q
if has_scales:
token_data = (tokens, token_scales)
token_data = (a1q, a1q_scale)

########################################################################
yield # Pre-dispatch done
########################################################################

(
token_data, expert_topk_ids, expert_topk_weights,
@@ -94,10 +126,8 @@
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
# to this value.
topk_idx=topk_ids,
topk_weights=topk_weights,
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
@@ -131,9 +161,12 @@
if self.async_prepare:
event.current_stream_wait()

# Unpack token data
if has_scales:
assert isinstance(token_data, tuple)
expert_x, expert_x_scale = token_data
else:
assert isinstance(token_data, torch.Tensor)
expert_x, expert_x_scale = token_data, None

# The existing MOE kernels assume that all entries of topk_ids are
@@ -174,58 +207,14 @@
per_act_token_quant=False,
block_shape=quant_config.block_shape)

########################################################################
yield # Dispatch send+recv done (sync)
########################################################################

return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)

def supports_async(self) -> bool:
return True

def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> Callable:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * topk_weights.to(a1.dtype)

if quant_config.is_block_quantized:
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
a1_post_scale = None
else:
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale

return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)

def prepare(
def create_prepare_ops(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
@@ -236,14 +225,14 @@
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()

def finalize(
) -> mk.SyncPrepareOps:
return mk.SyncPrepareOps.from_generator(
self._create_prepare_ops(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config))

def _create_finalize_ops(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
@@ -268,6 +257,10 @@
apply_router_weight_on_input=apply_router_weight_on_input,
)

########################################################################
yield # Pre-combine done
########################################################################

combined_x, _, event = self.buffer.combine(
x=fused_expert_output,
handle=self.handle,
@@ -278,3 +271,24 @@
allocate_on_comm_stream=False)
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)

########################################################################
yield # Combine send-recv done
########################################################################

return None

def create_finalize_ops(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> mk.SyncFinalizeOps:
return mk.SyncFinalizeOps.from_generator(
self._create_finalize_ops(output, fused_expert_output,
topk_weights, topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl))
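
The refactor above rewrites `prepare`/`finalize` as generators with explicit `yield` checkpoints (pre-dispatch, dispatch send+recv done, pre-combine, combine send+recv done) and hands them to `mk.SyncPrepareOps.from_generator` / `mk.SyncFinalizeOps.from_generator`. Those wrapper types are not shown in this diff, so as a mental model only, here is a hedged sketch of a driver over such a checkpointed generator; the class name and methods are assumptions, not vLLM's actual `modular_kernel` API.

```python
from typing import Any, Generator, Optional


class SyncOps:
    """Hypothetical driver for a generator with yield checkpoints.

    Each `yield` marks a point where a scheduler could switch to another
    micro-batch (or overlap other work) before resuming this one.
    """

    def __init__(self, gen: Generator[None, None, Any]):
        self._gen = gen
        self._result: Optional[Any] = None

    @classmethod
    def from_generator(cls, gen: Generator[None, None, Any]) -> "SyncOps":
        return cls(gen)

    def advance(self) -> bool:
        """Run to the next checkpoint; return False when the generator finishes."""
        try:
            next(self._gen)
            return True
        except StopIteration as stop:
            self._result = stop.value  # the generator's `return` value
            return False

    def run_to_completion(self) -> Any:
        while self.advance():
            pass
        return self._result
```

Usage would then look like `SyncOps.from_generator(self._create_prepare_ops(...)).run_to_completion()`, with more interesting schedulers calling `advance()` on several micro-batches in an interleaved order.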