
Commit 7ef40bb

Authored by Varun Sundar Rabindranath and mgoin
[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (vllm-project#25488)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent 767cbb0 commit 7ef40bb
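The commit title is the whole story at this level: gpt-oss MoE layers quantized to MXFP4 can now take the Marlin kernel path when running with data parallelism and expert parallelism. Purely as an illustration of the kind of deployment this targets, a hypothetical offline-inference sketch; the model id and the `data_parallel_size` / `enable_expert_parallel` engine arguments are assumptions about the surrounding vLLM version, not something this commit adds:

```python
# Hypothetical DP/EP launch sketch (not part of this commit). Assumes this
# vLLM build exposes `data_parallel_size` and `enable_expert_parallel` engine
# args and that the gpt-oss MXFP4 MoE layers fall back to the Marlin kernels.
from vllm import LLM, SamplingParams

llm = LLM(
    model="openai/gpt-oss-20b",   # assumed model id
    data_parallel_size=2,         # DP ranks (assumed engine arg)
    enable_expert_parallel=True,  # EP across the DP ranks (assumed engine arg)
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```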

File tree

9 files changed: +264 −101 lines

docs/design/moe_kernel_features.md

Lines changed: 4 additions & 2 deletions
@@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
 | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
 | deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
 | marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
+| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
 | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |

@@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to work together
 
 | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
 |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
-| deepep_high_throughput,</br>pplx | `DeepEPHTPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8` |
-| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8` |
+| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`,</br>`MarlinExperts` |
+| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8` |
 | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
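Read together, the two tables say that a backend picks one `FusedMoEPrepareAndFinalize` subclass and pairs it with a compatible `FusedMoEPermuteExpertsUnpermute` implementation; this commit adds `MarlinExperts` as a standard-format option for the `deepep_high_throughput` family. A minimal, self-contained sketch of that composition pattern (toy classes that only mirror the vLLM concepts, not the real interfaces):

```python
# Toy illustration of the prepare -> experts -> finalize composition described
# by the tables above. Names mirror the vLLM concepts; bodies are stand-ins.
from dataclasses import dataclass

import torch


class PrepareAndFinalize:
    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # e.g. quantize and dispatch tokens to expert ranks (no-op here)
        return hidden_states

    def finalize(self, output: torch.Tensor) -> torch.Tensor:
        # e.g. combine / all-to-all the results back to the source ranks
        return output


class Experts:
    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # stand-in for the fused expert GEMMs (Marlin, Triton, DeepGEMM, ...)
        return x * 2.0


@dataclass
class ModularKernel:
    prepare_finalize: PrepareAndFinalize
    experts: Experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.prepare_finalize.prepare(hidden_states)
        y = self.experts.apply(x)
        return self.prepare_finalize.finalize(y)


kernel = ModularKernel(PrepareAndFinalize(), Experts())
print(kernel.forward(torch.ones(4, 8)).shape)
```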

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ def apply(
 
         assert w2.size(1) == K
 
-        E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids)
 
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
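The same one-line change appears in cutlass_moe.py, fused_batched_moe.py and fused_moe.py below: the module-level helper `mk._moe_problem_size(...)` is now reached through an overridable `self.moe_problem_size(...)` method (presumably delegating to the old helper by default), which is what lets the new `MarlinExperts` class later in this diff derive its problem size from marlin-packed weight shapes. A toy sketch of that pattern, with hypothetical class names rather than the vLLM ones:

```python
# Toy sketch of turning a free helper into an overridable method.
# `BaseExperts` / `PackedExperts` are illustrative names, not vLLM classes.
import torch
from typing_extensions import override


def _problem_size_helper(a, w1, w2, topk_ids):
    # Default: read E/M/N/K directly from unpacked weight shapes (w1: (E, 2N, K)).
    E, N2, _ = w1.shape
    return E, a.size(0), N2 // 2, a.size(-1), topk_ids.size(1)


class BaseExperts:
    def moe_problem_size(self, a, w1, w2, topk_ids):
        return _problem_size_helper(a, w1, w2, topk_ids)


class PackedExperts(BaseExperts):
    @override
    def moe_problem_size(self, a, w1, w2, topk_ids):
        # Packed layouts cannot use w1.shape directly; derive N differently.
        E = w1.size(0)
        N = w2.size(1) * 16  # e.g. marlin packs 16 logical elements per entry
        return E, a.size(0), N, a.size(-1), topk_ids.size(1)


print(PackedExperts().moe_problem_size(
    torch.randn(5, 64), torch.empty(8, 4, 1), torch.empty(8, 2, 1),
    torch.zeros(5, 2, dtype=torch.long)))  # -> (8, 5, 32, 64, 2)
```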

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 1 addition & 1 deletion
@@ -712,7 +712,7 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
+        e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
         n = w2.shape[2] * 2
 
         run_cutlass_moe_fp4(

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 1 addition & 1 deletion
@@ -906,7 +906,7 @@ def apply(
 
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
 
-        E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
             hidden_states, w1, w2, topk_ids)
 
         assert w1.size(0) == E

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 164 additions & 26 deletions
@@ -4,11 +4,18 @@
 from typing import Optional
 
 import torch
+from typing_extensions import override
 
 import vllm._custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
+    marlin_make_workspace_new, marlin_moe_intermediate_size,
+    maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import direct_register_custom_op
 
@@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      bias2: Optional[torch.Tensor],
                      w1_scale: torch.Tensor,
                      w2_scale: torch.Tensor,
-                     gating_output: torch.Tensor,
+                     gating_output: Optional[torch.Tensor],
                      topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor,
                      quant_type_id: int,
@@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      w1_zeros: Optional[torch.Tensor] = None,
                      w2_zeros: Optional[torch.Tensor] = None,
                      workspace: Optional[torch.Tensor] = None,
+                     intermediate_cache13: Optional[torch.Tensor] = None,
+                     intermediate_cache2: Optional[torch.Tensor] = None,
                      is_k_full: bool = True,
+                     output: Optional[torch.Tensor] = None,
                      inplace: bool = False) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     - w2 (torch.Tensor): The second set of expert weights.
     - w1_scale (torch.Tensor): Scale to be used for w1.
     - w2_scale (torch.Tensor): Scale to be used for w2.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
+    - gating_output (Optional[torch.Tensor]): The output of the gating
+        operation (before softmax).
     - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
     - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
     - sort_indices1 (Optional[torch.Tensor]): The first act_order input
@@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     num_bits = 4 if quant_type in bit4_scalar_types else 8
 
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[
-        0], "Number of tokens mismatch"
+    if gating_output is not None:
+        assert hidden_states.shape[0] == gating_output.shape[
+            0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
     assert hidden_states.shape[1] == w2.shape[2] // (
@@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
 
     M, K = hidden_states.shape
     E = w1.shape[0]
-    N = w2.shape[1] * 16
+    N = marlin_moe_intermediate_size(w1, w2)
     topk = topk_ids.shape[1]
 
     # M block size selection logic
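The line being replaced documents the arithmetic the new helper has to perform: for 4-bit marlin-packed expert weights, the intermediate size N was previously recovered as `w2.shape[1] * 16`. A sketch of what `marlin_moe_intermediate_size` plausibly computes, stated as an assumption inferred only from this diff rather than the real implementation:

```python
import torch


def marlin_moe_intermediate_size_sketch(w1: torch.Tensor,
                                        w2: torch.Tensor) -> int:
    # Assumption: w2 is marlin-packed along dim 1 by a factor of 16, so the
    # logical intermediate size is w2.size(1) * 16 (the expression this call
    # replaces). The real helper may also validate against w1's packed shape.
    return w2.size(1) * 16
```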
@@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if workspace is None:
         workspace = marlin_make_workspace_new(hidden_states.device, 4)
 
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache13 = torch.empty(
-        (M * topk_ids.shape[1] * max(2 * N, K), ),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
-    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
-    intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
-    intermediate_cache3 = intermediate_cache3.view(-1, K)
+    if intermediate_cache2 is None:
+        intermediate_cache2 = torch.empty(
+            (M * topk, N),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    if intermediate_cache13 is None:
+        intermediate_cache13 = torch.empty(
+            (M * topk * max(2 * N, K), ),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    intermediate_cache1 = _resize_cache(intermediate_cache13,
+                                        (M * topk, 2 * N))
+    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
+    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
 
     maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
     use_atomic_add = hidden_states.dtype == torch.half or \
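The rewritten allocation lets a caller (here, the modular-kernel path added below) hand in pre-allocated flat buffers, which are then viewed into the shapes each stage needs via `_resize_cache`. A self-contained sketch of that slice-and-view idea, assuming `_resize_cache` behaves like the helper below (the real utility lives in `fused_moe/utils.py`):

```python
import math

import torch


def resize_cache_sketch(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # View the first prod(shape) elements of a flat scratch buffer as `shape`,
    # so one allocation can back several differently shaped intermediates.
    numel = math.prod(shape)
    assert buf.numel() >= numel, "scratch buffer too small"
    return buf.flatten()[:numel].view(shape)


M, topk, N, K = 4, 2, 6, 8
cache13 = torch.empty(M * topk * max(2 * N, K))
cache1 = resize_cache_sketch(cache13, (M * topk, 2 * N))  # first GEMM output
cache3 = resize_cache_sketch(cache13, (M * topk, K))      # second GEMM output
print(cache1.shape, cache3.shape)
```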
@@ -200,18 +215,17 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         use_fp32_reduce=True,
         is_zp_float=False).view(-1, topk, K)
 
-    output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1,
-                     out=output)
+    if output is None:
+        output = hidden_states if inplace else torch.empty_like(hidden_states)
+    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
 
 
 def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
                           w1_scale: torch.Tensor,
                           w2_scale: torch.Tensor,
-                          gating_output: torch.Tensor,
+                          gating_output: Optional[torch.Tensor],
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           quant_type_id: int,
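The rewritten tail makes the final combine explicit: the per-expert results are viewed as `(num_tokens, topk, K)` and summed over the `topk` dimension, now optionally straight into a caller-provided `output` buffer. A tiny numeric illustration of that reduction on toy data:

```python
import torch

M, topk, K = 3, 2, 4
# Pretend these are the routed, weight-multiplied expert outputs per token.
per_expert = torch.arange(M * topk * K, dtype=torch.float32)
out = torch.empty(M, K)
# Same reduction as in the diff: sum the topk expert contributions per token.
torch.sum(per_expert.view(-1, topk, K), dim=1, out=out)
print(out.shape)  # torch.Size([3, 4])
```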
@@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1_zeros: Optional[torch.Tensor] = None,
                           w2_zeros: Optional[torch.Tensor] = None,
                           workspace: Optional[torch.Tensor] = None,
+                          intermediate_cache13: Optional[torch.Tensor] = None,
+                          intermediate_cache2: Optional[torch.Tensor] = None,
                           is_k_full: bool = True,
+                          output: Optional[torch.Tensor] = None,
                           inplace: bool = False) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -237,3 +254,124 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
     op_func=fused_marlin_moe,
     fake_impl=fused_marlin_moe_fake,
 )
+
+
+class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
+
+    @override
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        assert w1.dim() == 3 and w2.dim() == 3
+
+        E = w1.size(0)
+        K = a1.size(-1)
+        N = marlin_moe_intermediate_size(w1, w2)
+
+        if a1.dim() == 2:
+            # Make sure we are using the correct a1 (pre-permute).
+            assert topk_ids.size(0) == a1.size(0), \
+                f"{topk_ids.size(0)} != {a1.size(0)}"
+            M = a1.size(0)
+        else:
+            assert a1.dim() == 3
+            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+            M = a1.size(1)  # This is max_num_tokens
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def workspace_shapes(
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # The Modular Kernel provisions the output buffer from workspace1.
+        # However, in fused_marlin_moe() the final torch.sum() is essentially
+        #   `torch.sum(workspace1, dim=1, out=output)`.
+        # Having overlapping input and output tensors for torch.sum seems
+        # error prone and depends on how torch.sum is implemented. For this
+        # reason we swap the buffers and let the output be provisioned from
+        # workspace2.
+
+        # Workspace/IntermediateCache allocation matching fused_marlin_moe():
+        # workspace1 = (M * topk * max(2 * N, K),)
+        # workspace2 = (M * topk, N)
+
+        # Workspace/IntermediateCache allocation accounting for output buffer
+        # provisioning:
+        workspace1 = (M * topk, max(N, K))
+        workspace2 = (M * topk * max(2 * N, K), )
+        output = (M, K)
+
+        return (workspace1, workspace2, output, a.dtype)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        assert self.w1_scale is not None
+        assert self.w2_scale is not None
+        return fused_marlin_moe(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            bias1=self.w1_bias,
+            bias2=self.w2_bias,
+            w1_scale=self.w1_scale,
+            w2_scale=self.w2_scale,
+            gating_output=None,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            activation=activation,
+            expert_map=expert_map,
+            output=output,
+            # Workspaces are swapped in workspace_shapes() to account for
+            # proper output buffer allocation. Please refer to
+            # workspace_shapes().
+            intermediate_cache13=workspace2,
+            intermediate_cache2=workspace13)
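The comment in `workspace_shapes()` carries the key design note: the modular kernel carves `output` out of workspace1, and `torch.sum(..., out=output)` should not read and write overlapping memory, so `MarlinExperts` sizes workspace1 for `intermediate_cache2` plus the output and hands the flat `intermediate_cache13` buffer over as workspace2 (the swap visible at the end of `apply()`). A small sanity check of those sizes for concrete dimensions, counting elements per buffer:

```python
# Sanity-check the buffer sizes chosen in workspace_shapes() for toy dims.
M, N, K, topk = 8, 2880, 2880, 4  # arbitrary example sizes

workspace1 = M * topk * max(N, K)      # backs intermediate_cache2 and output
workspace2 = M * topk * max(2 * N, K)  # backs the intermediate_cache1/3 views
output = M * K

assert M * topk * N <= workspace1      # cache2 = (M*topk, N) fits
assert output <= workspace1            # output = (M, K) fits since topk >= 1
assert M * topk * 2 * N <= workspace2  # cache1 = (M*topk, 2N) fits
assert M * topk * K <= workspace2      # cache3 = (M*topk, K) fits
print("workspace sizing consistent")
```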

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -1780,7 +1780,7 @@ def apply(
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]
 
-        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids)
 
        if global_num_experts == -1:
