
Commit 268b1c5

Authored by robertgshaw2-redhat (Robert Shaw) and yewentao256 (Wentao Ye)

[MoE Refactor][13/N] Convert FI to Use PFNoEP (vllm-project#31533)

Signed-off-by: Robert Shaw <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent: 4f9ce35

File tree: 7 files changed (+78, -168 lines)

vllm/model_executor/layers/fused_moe/all2all_utils.py (13 additions, 3 deletions)

@@ -15,6 +15,9 @@
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPrepareAndFinalize,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
+)
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import has_deep_ep, has_pplx

@@ -77,10 +80,17 @@ def maybe_make_prepare_finalize(

     prepare_finalize: FusedMoEPrepareAndFinalize | None = None

-    # TODO: could allow this now
-    assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"
+    if moe.use_flashinfer_cutlass_kernels:
+        assert quant_config is not None
+        use_deepseek_fp8_block_scale = (
+            quant_config is not None and quant_config.is_block_quantized
+        )
+        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+            moe=moe,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        )

-    if moe.use_pplx_kernels:
+    elif moe.use_pplx_kernels:
         assert quant_config is not None

         hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
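
For orientation, here is a condensed sketch of the new dispatch order in maybe_make_prepare_finalize: the FlashInfer CUTLASS case is now built inline (instead of asserting it was created in modelopt.py), and the pplx branch becomes an elif. The stub dataclasses and string return values below are illustrative stand-ins, not the real vLLM types.

from dataclasses import dataclass


# Illustrative stand-ins for vLLM's MoE / quantization configs; only the
# attributes that drive the dispatch are modeled here.
@dataclass
class MoEStub:
    use_flashinfer_cutlass_kernels: bool = False
    use_pplx_kernels: bool = False


@dataclass
class QuantStub:
    is_block_quantized: bool = False


def pick_prepare_finalize(moe: MoEStub, quant_config: QuantStub | None) -> str:
    if moe.use_flashinfer_cutlass_kernels:
        # FlashInfer CUTLASS FP8: build the prepare/finalize here, passing the
        # block-scale flag derived from the quant config.
        assert quant_config is not None
        block_scale = quant_config.is_block_quantized
        return f"flashinfer_fp8_cutlass(block_scale={block_scale})"
    elif moe.use_pplx_kernels:
        return "pplx"
    return "none"


print(pick_prepare_finalize(MoEStub(use_flashinfer_cutlass_kernels=True), QuantStub(True)))
# -> flashinfer_fp8_cutlass(block_scale=True)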

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py (17 additions, 5 deletions)

@@ -10,6 +10,9 @@
 )
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )

@@ -349,14 +352,23 @@ def create_flashinfer_prepare_finalize(
     use_nvfp4: bool = False,
     enable_alltoallv: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
-) -> FlashInferCutlassMoEPrepareAndFinalize:
+) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
     """Factory function to create the appropriate FlashInfer implementation."""
+    # TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
+    # once we complete the FP8 refactor.
     if use_nvfp4:
         if enable_alltoallv:
             return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
         else:
             return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
-    # FP8 path currently supported via AllGather; optionally enable block-scale
-    return FlashInferAllGatherMoEPrepareAndFinalize(
-        use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
-    )
+
+    # FP8 DP path currently supported via AllGather.
+    if use_dp:
+        return FlashInferAllGatherMoEPrepareAndFinalize(
+            use_dp=True,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        )
+    else:
+        # NOTE(rob): CUTLASS FP8 block quant executes the input
+        # quantization and grouped gemm in a single kernel.
+        return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
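
The factory now has three outcomes instead of two. A condensed sketch of the selection logic, returning class names as strings rather than constructing the real objects (which need vLLM's runtime state), under the assumption that the arguments mirror create_flashinfer_prepare_finalize:

def flashinfer_prepare_finalize_kind(
    use_dp: bool,
    use_nvfp4: bool = False,
    enable_alltoallv: bool = False,
    use_deepseek_fp8_block_scale: bool = False,
) -> str:
    # NVFP4 keeps the existing AllToAll / AllGather implementations.
    if use_nvfp4:
        return (
            "FlashInferAllToAllMoEPrepareAndFinalize"
            if enable_alltoallv
            else "FlashInferAllGatherMoEPrepareAndFinalize"
        )
    # FP8 with data parallelism still goes through AllGather.
    if use_dp:
        return "FlashInferAllGatherMoEPrepareAndFinalize"
    # FP8 without DP now takes the no-EP path; input quantization is deferred
    # when block scaling is used, since the CUTLASS kernel fuses quantization
    # with the grouped GEMM.
    return f"MoEPrepareAndFinalizeNoEP(defer_input_quant={use_deepseek_fp8_block_scale})"


assert (
    flashinfer_prepare_finalize_kind(use_dp=False, use_deepseek_fp8_block_scale=True)
    == "MoEPrepareAndFinalizeNoEP(defer_input_quant=True)"
)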

vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (0 additions, 1 deletion)

@@ -49,7 +49,6 @@ def make(
                 prepare_finalize,
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
-                getattr(moe_layer, "shared_experts_stream", None),
                 moe_parallel_config=moe_layer.moe_parallel_config,
             ),
         )

vllm/model_executor/layers/fused_moe/layer.py (2 additions, 2 deletions)

@@ -356,14 +356,14 @@ def __init__(
         # TODO: Remove this after more extensive testings with TP/DP
         # and other execution modes
         if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
-            logger.info_once("Disabling MoE shared_experts cuda stream")
+            logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
             self.shared_experts_stream = None
         else:
             # TODO(rob): enable shared expert overlap with non-cuda-alike.
             # aux_stream() returns None on non-cuda-alike platforms.
             self.shared_experts_stream = aux_stream()
             if self.shared_experts_stream is not None:
-                logger.info_once(
+                logger.debug_once(
                     "Enabled separate cuda stream for MoE shared_experts", scope="local"
                 )

vllm/model_executor/layers/fused_moe/modular_kernel.py (5 additions, 75 deletions)

@@ -21,7 +21,6 @@
     count_expert_num_tokens,
     disable_inplace,
 )
-from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.v1.worker.ubatching import (
     dbo_enabled,

@@ -682,14 +681,12 @@ def __init__(
         prepare_finalize: FusedMoEPrepareAndFinalize,
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
-        shared_experts_stream: torch.cuda.Stream | None = None,
         moe_parallel_config: FusedMoEParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
-        self.shared_experts_stream = shared_experts_stream

         # prefer an explicit FusedMoEParallelConfig when available (from
         # FusedMoE layers / tests).

@@ -904,34 +901,6 @@ def _slice_expert_tokens_metadata(
             expert_num_tokens_cpu=c_expert_num_tokens_cpu,
         )

-    def _maybe_setup_shared_experts_stream(
-        self, hidden_states: torch.Tensor
-    ) -> tuple[bool, torch.Tensor | None]:
-        # decide whether to run shared experts on a separate CUDA stream to
-        # overlap with the main fused MoE kernel.
-        use_shared_experts_stream = (
-            self.shared_experts is not None
-            and self.shared_experts_stream is not None
-            and hidden_states.is_cuda
-            and (
-                hidden_states.shape[0]
-                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
-            )
-        )
-
-        hidden_states_clone: torch.Tensor | None = None
-        if use_shared_experts_stream and self.shared_experts_stream is not None:
-            # TODO: Optimize this (complicated)
-            # Note: this clone adds overhead but is required
-            # for correctness with multiple CUDA streams and CUDA graph capture.
-            hidden_states_clone = hidden_states.clone()
-            # record that the clone will be used by the separate stream so its
-            # lifetime is correctly tracked.
-            hidden_states_clone.record_stream(self.shared_experts_stream)
-            self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
-
-        return use_shared_experts_stream, hidden_states_clone
-
     def _prepare(
         self,
         hidden_states: torch.Tensor,

@@ -1119,30 +1088,12 @@ def _finalize(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
-        hidden_states_clone: torch.Tensor | None = None,
-        use_shared_experts_stream: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         The _finalize method is a wrapper around self.prepare_finalize.finalize
         that handles DBO, async and shared expert overlap.
         """
-
-        def maybe_run_shared_experts() -> torch.Tensor | None:
-            if self.shared_experts is None:
-                return None
-
-            if (
-                not use_shared_experts_stream
-                or self.shared_experts_stream is not None
-                and (not hidden_states.is_cuda or not torch.cuda.is_available())
-            ):
-                # fall back to running on the current stream
-                return self.shared_experts(hidden_states)
-
-            assert hidden_states_clone is not None
-            # launch shared experts on the dedicated stream.
-            with torch.cuda.stream(self.shared_experts_stream):
-                return self.shared_experts(hidden_states_clone)
+        shared_output: torch.Tensor | None = None

         if not self.prepare_finalize.supports_async():
             assert not dbo_enabled()

@@ -1155,7 +1106,8 @@ def maybe_run_shared_experts() -> torch.Tensor | None:
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
-            shared_output = maybe_run_shared_experts()
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
         else:
             finalize_ret = self.prepare_finalize.finalize_async(
                 output,

@@ -1165,8 +1117,8 @@ def maybe_run_shared_experts() -> torch.Tensor | None:
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
-
-            shared_output = maybe_run_shared_experts()
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)

             # TODO(lucas): refactor this in the alternative schedules followup
             # currently unpack if we have hook + receiver pair or just

@@ -1189,28 +1141,12 @@ def maybe_run_shared_experts() -> torch.Tensor | None:

             receiver()

-        self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
-
         if self.shared_experts is None:
             return output
         else:
             assert shared_output is not None
             return shared_output, output

-    def _wait_for_shared_experts_stream(
-        self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
-    ) -> None:
-        # ensure that any work enqueued on the shared_experts_stream is
-        # completed before the shared_output tensor is consumed
-        if (
-            self.shared_experts is not None
-            and use_shared_experts_stream
-            and self.shared_experts_stream is not None
-            and hidden_states.is_cuda
-            and current_platform.is_cuda()
-        ):
-            torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -1257,10 +1193,6 @@ def forward(
         else:
             output = torch.zeros_like(hidden_states)

-        use_shared_experts_stream, hidden_states_clone = (
-            self._maybe_setup_shared_experts_stream(hidden_states)
-        )
-
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts

@@ -1297,6 +1229,4 @@ def forward(
             topk_weights,
             topk_ids,
             apply_router_weight_on_input,
-            hidden_states_clone=hidden_states_clone,
-            use_shared_experts_stream=use_shared_experts_stream,
         )
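
The net effect in FusedMoEModularKernel is that shared experts now run inline on the current stream; the dedicated-stream machinery (input clone, record_stream, wait_stream) is removed along with the shared_experts_stream constructor argument. A minimal sketch of the remaining pattern, with a hypothetical stub module standing in for the real shared-experts layer:

import torch


class SharedExpertsStub(torch.nn.Module):
    # Hypothetical stand-in for the real shared-experts module.
    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def finalize_outputs(
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    shared_experts: torch.nn.Module | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    # Shared experts, when present, run on the current stream right after
    # finalize; no separate CUDA stream, clone, or stream synchronization.
    shared_output: torch.Tensor | None = None
    if shared_experts is not None:
        shared_output = shared_experts(hidden_states)

    if shared_experts is None:
        return output
    assert shared_output is not None
    return shared_output, output


hs = torch.randn(4, 8)
result = finalize_outputs(torch.zeros_like(hs), hs, SharedExpertsStub(8))
assert isinstance(result, tuple) and len(result) == 2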
