
Commit 3778673

Authored by yewentao256 and Robert Shaw (robertgshaw2-redhat)
[Feat] Refactor for parallel_config in FusedMoEModularKernel (vllm-project#30282)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
1 parent b337647 commit 3778673

8 files changed: +32 additions, -27 deletions
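
For orientation, a minimal sketch of the new construction path. This is not part of the diff: parallel_cfg is a dummy single-rank config whose field names mirror the test helper added in tests/kernels/moe/test_flashinfer.py below, and prepare_finalize / fused_experts stand in for already-built FusedMoEPrepareAndFinalize and FusedMoEPermuteExpertsUnpermute objects.

    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    # Dummy single-rank parallel config (field names taken from the test below).
    parallel_cfg = mk.FusedMoEParallelConfig(
        tp_size=1,
        pcp_size=1,
        dp_size=1,
        ep_size=1,
        tp_rank=0,
        pcp_rank=0,
        dp_rank=0,
        ep_rank=0,
        use_ep=False,
        all2all_backend="naive",
    )

    # moe_parallel_config replaces the old parallel_config keyword.
    kernel = mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
        moe_parallel_config=parallel_cfg,
    )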

tests/kernels/moe/modular_kernel_tools/common.py (2 additions, 1 deletion)

@@ -594,7 +594,8 @@ def next_power_of_2(x):
     )
 
     modular_kernel = mk.FusedMoEModularKernel(
-        prepare_finalize=prepare_finalize, fused_experts=fused_experts
+        prepare_finalize=prepare_finalize,
+        fused_experts=fused_experts,
     )
 
     return modular_kernel

tests/kernels/moe/test_flashinfer.py (14 additions, 0 deletions)

@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
@@ -107,6 +108,19 @@ def make_moe_tensors_8bit(
     layer.w2_input_scale = a2_scale
     layer.w13_weight_scale = w13_weight_scale
     layer.w2_weight_scale = w2_weight_scale
+    # Setup dummy config.
+    layer.moe_parallel_config = mk.FusedMoEParallelConfig(
+        tp_size=1,
+        pcp_size=1,
+        dp_size=1,
+        ep_size=1,
+        tp_rank=1,
+        pcp_rank=1,
+        dp_rank=1,
+        ep_rank=1,
+        use_ep=False,
+        all2all_backend="naive",
+    )
 
     register_moe_scaling_factors(layer)
vllm/model_executor/layers/fused_moe/cutlass_moe.py (0 additions, 2 deletions)

@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
-    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -538,7 +537,6 @@
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
-        parallel_config=parallel_config,
     )
 
     return fn(

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py (1 addition, 1 deletion)

@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
     expert_map: torch.Tensor | None = None,
     a1_scale: torch.Tensor | None = None,
     a2_scale: torch.Tensor | None = None,
-    apply_router_weight_on_input=False,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer

vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (1 addition, 6 deletions)

@@ -43,19 +43,14 @@ def make(
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
-        parallel_config = getattr(
-            getattr(moe_layer, "vllm_config", None),
-            "parallel_config",
-            None,
-        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
                 prepare_finalize,
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
-                parallel_config=parallel_config,
+                moe_parallel_config=moe_layer.moe_parallel_config,
             ),
         )
vllm/model_executor/layers/fused_moe/modular_kernel.py (13 additions, 8 deletions)

@@ -10,10 +10,12 @@
 import torch
 
 import vllm.envs as envs
-from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEParallelConfig,
+    FusedMoEQuantConfig,
+)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
     count_expert_num_tokens,
@@ -681,20 +683,23 @@ def __init__(
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
-        parallel_config: ParallelConfig | None = None,
+        moe_parallel_config: FusedMoEParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream
 
-        # cache whether this worker is using DP+EP
-        if parallel_config is None:
-            parallel_config = get_current_vllm_config().parallel_config
+        # prefer an explicit FusedMoEParallelConfig when available (from
+        # FusedMoE layers / tests).
+        # if not provided, assume this kernel is
+        # running in a non-DP+EP context
+        self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
         self.is_dp_ep = (
-            parallel_config.data_parallel_size > 1
-            and parallel_config.enable_expert_parallel
+            moe_parallel_config is not None
+            and moe_parallel_config.dp_size > 1
+            and moe_parallel_config.use_ep
         )
 
         self._post_init_setup()
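
The behavioral shift in the hunk above: the kernel no longer falls back to get_current_vllm_config().parallel_config when no config is passed. A rough illustration of the new is_dp_ep resolution (resolve_is_dp_ep is a hypothetical helper, not in the codebase; the attribute names come from the diff):

    def resolve_is_dp_ep(moe_parallel_config):
        # With no FusedMoEParallelConfig supplied, the kernel now assumes a
        # non-DP+EP context rather than consulting the global vLLM config.
        return (
            moe_parallel_config is not None
            and moe_parallel_config.dp_size > 1
            and moe_parallel_config.use_ep
        )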

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py (0 additions, 3 deletions)

@@ -1266,9 +1266,6 @@ def apply(
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
-                parallel_config=getattr(
-                    getattr(layer, "vllm_config", None), "parallel_config", None
-                ),
             )
 
         else:

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py (1 addition, 6 deletions)

@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None
 
     # Construct modular kernel with block-scale support when requested.
-    parallel_config = getattr(
-        getattr(layer, "vllm_config", None),
-        "parallel_config",
-        None,
-    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -262,7 +257,7 @@
             out_dtype=hidden_states.dtype,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
-        parallel_config=parallel_config,
+        moe_parallel_config=layer.moe_parallel_config,
     )
 
     return fused_experts(
