Skip to content

Commit c97234c

Browse files
elizabethtclaudemgoin
authored
fix(mxfp4): Disable monolithic path for TRITON backend with EP (vllm-project#34270)
Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent b188bab commit c97234c

File tree

2 files changed

+225
-5
lines changed

2 files changed

+225
-5
lines changed
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Tests that triton_kernel_moe_forward correctly applies expert_map
5+
remapping when expert parallelism (EP) is enabled.
6+
7+
Previously, legacy_routing was always used and it produced routing data
8+
with global expert IDs that didn't correspond to local weight indices,
9+
causing illegal memory access with EP. The fix splits routing: when
10+
expert_map is provided, topk selection is performed first, expert_map is
11+
applied to remap global→local IDs, and make_routing_data builds routing
12+
structures from the local IDs.
13+
"""
14+
15+
from unittest.mock import MagicMock, patch
16+
17+
import pytest
18+
import torch
19+
20+
from vllm.model_executor.layers.quantization.mxfp4 import (
21+
Mxfp4Backend,
22+
Mxfp4MoEMethod,
23+
)
24+
25+
26+
def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
27+
"""Create a mock FusedMoEConfig with the given EP size."""
28+
parallel_config = MagicMock()
29+
parallel_config.ep_size = ep_size
30+
31+
moe_config = MagicMock()
32+
moe_config.ep_size = ep_size
33+
moe_config.is_lora_enabled = False
34+
moe_config.moe_parallel_config = parallel_config
35+
return moe_config
36+
37+
38+
class TestMxfp4TritonIsMonolithic:
    """Verify that is_monolithic is always True for the TRITON backend,
    regardless of EP size, since triton_kernel_moe_forward now handles
    expert_map remapping internally."""

    # Each case: (backend, ep_size, expected is_monolithic value).
    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            (Mxfp4Backend.TRITON, 1, True),
            (Mxfp4Backend.TRITON, 2, True),
            (Mxfp4Backend.TRITON, 4, True),
            # SM100 backends are always monolithic
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
            # MARLIN is never monolithic
            (Mxfp4Backend.MARLIN, 1, False),
            (Mxfp4Backend.MARLIN, 2, False),
        ],
        ids=[
            "triton-no-ep",
            "triton-ep2",
            "triton-ep4",
            "sm100-trtllm-no-ep",
            "sm100-trtllm-ep2",
            "sm100-bf16-no-ep",
            "sm100-bf16-ep2",
            "marlin-no-ep",
            "marlin-ep2",
        ],
    )
    # NOTE: decorators are applied bottom-up, so the innermost @patch
    # (get_current_vllm_config) supplies the first mock argument.
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """is_monolithic should be True for TRITON regardless of EP size."""
        mock_get_backend.return_value = backend

        # Mxfp4MoEMethod.__init__ reads the compilation config from the
        # current vllm config; provide a minimal stand-in.
        compilation_cfg = MagicMock()
        compilation_cfg.max_cudagraph_capture_size = 1024
        vllm_cfg = MagicMock()
        vllm_cfg.compilation_config = compilation_cfg
        mock_get_config.return_value = vllm_cfg

        method = Mxfp4MoEMethod(_make_mock_moe_config(ep_size=ep_size))

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )
102+
103+
104+
class TestTritonMoeForwardExpertMap:
    """Test that triton_kernel_moe_forward applies expert_map remapping
    when expert_map is provided (EP active)."""

    @pytest.mark.parametrize("expert_map_present", [False, True])
    def test_routing_path_selection(self, expert_map_present):
        """Verify that the EP-aware routing path is taken when expert_map
        is present, and the legacy_routing path is taken otherwise."""

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # This is a structural test: we mock the routing functions to
        # verify the correct path is exercised.
        if expert_map_present:
            # Global expert IDs 0 and 2 map to local slots 0 and 1;
            # -1 marks experts not owned by this rank.
            mock_expert_map = torch.tensor([0, -1, 1, -1], device=device)
        else:
            mock_expert_map = None

        with (
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.legacy_routing"
            ) as mock_legacy,
            patch("triton_kernels.topk.topk") as mock_topk,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.make_routing_data"
            ) as mock_make_routing,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
            ) as mock_fused_experts,
        ):
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            # Set up return values
            mock_routing_data = MagicMock()
            mock_gather = MagicMock()
            mock_scatter = MagicMock()

            if expert_map_present:
                # EP path consumes topk()'s sparse result, then builds
                # routing structures via make_routing_data.
                sparse_result = MagicMock()
                sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
                sparse_result.vals = torch.tensor([[0.6, 0.4]])
                mock_topk.return_value = sparse_result
                mock_make_routing.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )
            else:
                mock_legacy.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )

            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)

            hidden = torch.randn((1, 8), device=device)
            w1 = torch.randn((2, 8, 16), device=device)
            w2 = torch.randn((2, 8, 8), device=device)
            logits = torch.randn((1, 4), device=device)

            triton_kernel_moe_forward(
                hidden_states=hidden,
                w1=w1,
                w2=w2,
                gating_output=logits,
                topk=2,
                renormalize=True,
                expert_map=mock_expert_map,
            )

            if expert_map_present:
                # EP path: should use topk + make_routing_data, NOT
                # legacy_routing
                mock_topk.assert_called_once()
                mock_make_routing.assert_called_once()
                mock_legacy.assert_not_called()
                # expert_map should be None in the fused_experts call
                # (already applied)
                call_kwargs = mock_fused_experts.call_args
                assert call_kwargs[1].get("expert_map") is None or (
                    len(call_kwargs[0]) > 0
                )
            else:
                # Non-EP path: should use legacy_routing
                mock_legacy.assert_called_once()
                mock_topk.assert_not_called()
                mock_make_routing.assert_not_called()

vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,35 @@ def triton_kernel_moe_forward(
179179
global_num_experts: int = -1,
180180
expert_map: torch.Tensor | None = None,
181181
) -> torch.Tensor:
182-
routing_data, gather_idx, scatter_idx = legacy_routing(
183-
gating_output, topk, sm_first=not renormalize
184-
)
182+
if expert_map is not None:
183+
# With expert parallelism, legacy_routing produces routing data
184+
# using global expert IDs which don't correspond to local weight
185+
# indices. Split the routing into topk selection + expert_map
186+
# remapping + local routing data construction (matching the
187+
# approach used by OAITritonExperts.apply).
188+
from triton_kernels.topk import topk as topk_fn
189+
190+
sm_first = not renormalize
191+
logits = gating_output
192+
if sm_first:
193+
logits = torch.softmax(logits, dim=-1)
194+
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
195+
# sparse_logits.indx contains global expert IDs – remap to local.
196+
topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
197+
topk_weights = sparse_logits.vals
198+
local_num_experts = w1.size(0)
199+
routing_data, gather_idx, scatter_idx = make_routing_data(
200+
topk_ids, topk_weights, local_num_experts
201+
)
202+
# expert_map already applied; pass None downstream.
203+
effective_expert_map = None
204+
effective_global_num_experts = local_num_experts
205+
else:
206+
routing_data, gather_idx, scatter_idx = legacy_routing(
207+
gating_output, topk, sm_first=not renormalize
208+
)
209+
effective_expert_map = expert_map
210+
effective_global_num_experts = global_num_experts
185211

186212
output = torch.empty_like(hidden_states)
187213

@@ -197,8 +223,8 @@ def triton_kernel_moe_forward(
197223
activation=activation,
198224
quant_config=quant_config,
199225
apply_router_weight_on_input=apply_router_weight_on_input,
200-
global_num_experts=global_num_experts,
201-
expert_map=expert_map,
226+
global_num_experts=effective_global_num_experts,
227+
expert_map=effective_expert_map,
202228
)
203229

204230

0 commit comments

Comments
 (0)