 from typing import Optional
 
 import torch
+from typing_extensions import override
 
 import vllm._custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
+    marlin_make_workspace_new, marlin_moe_intermediate_size,
+    maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import direct_register_custom_op
 
@@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      bias2: Optional[torch.Tensor],
                      w1_scale: torch.Tensor,
                      w2_scale: torch.Tensor,
-                     gating_output: torch.Tensor,
+                     gating_output: Optional[torch.Tensor],
                      topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor,
                      quant_type_id: int,
@@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      w1_zeros: Optional[torch.Tensor] = None,
                      w2_zeros: Optional[torch.Tensor] = None,
                      workspace: Optional[torch.Tensor] = None,
+                     intermediate_cache13: Optional[torch.Tensor] = None,
+                     intermediate_cache2: Optional[torch.Tensor] = None,
                      is_k_full: bool = True,
+                     output: Optional[torch.Tensor] = None,
                      inplace: bool = False) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     - w2 (torch.Tensor): The second set of expert weights.
     - w1_scale (torch.Tensor): Scale to be used for w1.
     - w2_scale (torch.Tensor): Scale to be used for w2.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
+    - gating_output (Optional[torch.Tensor]): The output of the gating
+        operation (before softmax).
     - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
     - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
     - sort_indices1 (Optional[torch.Tensor]): The first act_order input
@@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     num_bits = 4 if quant_type in bit4_scalar_types else 8
 
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[
-        0], "Number of tokens mismatch"
+    if gating_output is not None:
+        assert hidden_states.shape[0] == gating_output.shape[
+            0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
     assert hidden_states.shape[1] == w2.shape[2] // (
@@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
 
     M, K = hidden_states.shape
     E = w1.shape[0]
-    N = w2.shape[1] * 16
+    N = marlin_moe_intermediate_size(w1, w2)
     topk = topk_ids.shape[1]
 
     # M block size selection logic
@@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if workspace is None:
         workspace = marlin_make_workspace_new(hidden_states.device, 4)
 
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache13 = torch.empty(
-        (M * topk_ids.shape[1] * max(2 * N, K), ),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
-    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
-    intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
-    intermediate_cache3 = intermediate_cache3.view(-1, K)
+    if intermediate_cache2 is None:
+        intermediate_cache2 = torch.empty(
+            (M * topk, N),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    if intermediate_cache13 is None:
+        intermediate_cache13 = torch.empty(
+            (M * topk * max(2 * N, K), ),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    intermediate_cache1 = _resize_cache(intermediate_cache13,
+                                        (M * topk, 2 * N))
+    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
+    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
 
     maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
     use_atomic_add = hidden_states.dtype == torch.half or \
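
The _resize_cache() helper imported above carves differently shaped views out of a single flat scratch buffer, which is how intermediate_cache1 and intermediate_cache3 end up sharing intermediate_cache13. A minimal standalone sketch of that pattern (assuming the helper simply views the leading elements of the flat buffer; resize_cache_sketch and the sizes below are illustrative, not part of this change):

import torch

def resize_cache_sketch(buf, shape):
    # View the first prod(shape) elements of the flat buffer as `shape`.
    numel = 1
    for dim in shape:
        numel *= dim
    return buf.flatten()[:numel].view(*shape)

m_topk, n, k = 128, 8, 12                               # illustrative sizes
scratch = torch.empty(m_topk * max(2 * n, k))           # like intermediate_cache13
cache1 = resize_cache_sketch(scratch, (m_topk, 2 * n))  # like intermediate_cache1
cache3 = resize_cache_sketch(scratch, (m_topk, k))      # like intermediate_cache3 (same memory)
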
@@ -200,18 +215,17 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
         use_fp32_reduce=True,
         is_zp_float=False).view(-1, topk, K)
 
-    output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1,
-                     out=output)
+    if output is None:
+        output = hidden_states if inplace else torch.empty_like(hidden_states)
+    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
 
 
 def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
                           w1_scale: torch.Tensor,
                           w2_scale: torch.Tensor,
-                          gating_output: torch.Tensor,
+                          gating_output: Optional[torch.Tensor],
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           quant_type_id: int,
@@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
                           w1_zeros: Optional[torch.Tensor] = None,
                           w2_zeros: Optional[torch.Tensor] = None,
                           workspace: Optional[torch.Tensor] = None,
+                          intermediate_cache13: Optional[torch.Tensor] = None,
+                          intermediate_cache2: Optional[torch.Tensor] = None,
                           is_k_full: bool = True,
+                          output: Optional[torch.Tensor] = None,
                           inplace: bool = False) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -237,3 +254,124 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
     op_func=fused_marlin_moe,
     fake_impl=fused_marlin_moe_fake,
 )
+
+
+class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun) : Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        super().__init__(quant_config)
+
+    @override
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        assert w1.dim() == 3 and w2.dim() == 3
+
+        E = w1.size(0)
+        K = a1.size(-1)
+        N = marlin_moe_intermediate_size(w1, w2)
+
+        if a1.dim() == 2:
+            # Make sure we are using the correct a1 (pre-permute).
+            assert topk_ids.size(0) == a1.size(0), \
+                f"{topk_ids.size(0)} != {a1.size(0)}"
+            M = a1.size(0)
+        else:
+            assert a1.dim() == 3
+            assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+            M = a1.size(1)  # This is max_num_tokens
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def workspace_shapes(
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # The Modular Kernel provisions the output buffer from workspace1.
+        # However, in the fused_marlin_moe() function the final torch.sum()
+        # is essentially
+        # `torch.sum(workspace1, dim=1, out=output)`.
+        # Having overlapping input and output tensors for torch.sum seems
+        # error prone and depends on how torch.sum is implemented.
+        # For this reason we swap the two workspaces and let the output
+        # buffer be provisioned from workspace2.
+
+        # Workspace/IntermediateCache allocation matching fused_marlin_moe()
+        # workspace1 = (M * topk * max(2 * N, K),)
+        # workspace2 = (M * topk, N)
+
+        # Workspace/IntermediateCache allocation accounting for output buffer
+        # provisioning
+        workspace1 = (M * topk, max(N, K))
+        workspace2 = (M * topk * max(2 * N, K), )
+        output = (M, K)
+
+        return (workspace1, workspace2, output, a.dtype)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        assert self.w1_scale is not None
+        assert self.w2_scale is not None
+        return fused_marlin_moe(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            bias1=self.w1_bias,
+            bias2=self.w2_bias,
+            w1_scale=self.w1_scale,
+            w2_scale=self.w2_scale,
+            gating_output=None,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            activation=activation,
+            expert_map=expert_map,
+            output=output,
+            # Workspaces are swapped in workspace_shapes() to account for proper
+            # output buffer allocation. Please refer to workspace_shapes().
+            intermediate_cache13=workspace2,
+            intermediate_cache2=workspace13)
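
To make the workspace swap in workspace_shapes() concrete, here is a standalone sizing sketch; the values of M, topk, N and K are illustrative assumptions, not taken from this change:

# Sizing sketch for the workspace swap in MarlinExperts.workspace_shapes().
M, topk, N, K = 32, 4, 2816, 4096          # illustrative dimensions

# Buffers fused_marlin_moe() carves out internally:
cache13_elems = M * topk * max(2 * N, K)   # backs intermediate_cache1/3
cache2_elems = M * topk * N                # backs intermediate_cache2
output_elems = M * K                       # final reduced output

# Sizes reported by workspace_shapes() after the swap:
workspace1_elems = M * topk * max(N, K)      # passed to apply() as intermediate_cache2
workspace2_elems = M * topk * max(2 * N, K)  # passed to apply() as intermediate_cache13

# workspace1 is large enough for both intermediate_cache2 and the (M, K)
# output that the modular kernel provisions from it, while the torch.sum()
# input (a view of intermediate_cache13) lives in workspace2, so the final
# reduction never reads and writes overlapping memory.
assert workspace1_elems >= cache2_elems
assert workspace1_elems >= output_elems
assert workspace2_elems >= cache13_elems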