@@ -13,7 +13,8 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
+                                                        _fp8_quantize,
                                                          _resize_cache,
                                                          extract_required_args)
 from vllm.scalar_type import scalar_types
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
                                     problem_sizes1, problem_sizes2, a_map,
                                     c_map, global_num_experts, N, K)

-    a1q = ops.shuffle_rows(a1q, a_map)
-    a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
-                 if per_act_token else a1q_scale)
+    a1q = _fp8_perm(a1q, a_map)
+    a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale

     expert_offsets = expert_offsets[:-1]

+    ab_strides1 = torch.full((w1.size(0), ),
+                             K,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides1 = torch.full((w1.size(0), ),
+                            2 * N,
+                            device=device,
+                            dtype=torch.int64)
+    ab_strides2 = torch.full((w1.size(0), ),
+                             N,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides2 = torch.full((w1.size(0), ),
+                            K,
+                            device=device,
+                            dtype=torch.int64)
+
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
-                     non_blocking=True)
+        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)


 # TODO (bnell): split class batched vs. non-batched?
@@ -211,10 +223,6 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -231,10 +239,6 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
-        self.ab_strides1 = ab_strides1
-        self.ab_strides2 = ab_strides2
-        self.c_strides1 = c_strides1
-        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format

     @property
@@ -314,8 +318,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -360,17 +359,6 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
-    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
-        Shape: [num_experts]
-    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
-        Shape: [num_experts]
-    - c_strides1 (torch.Tensor): The output strides for the first gemm.
-        Shape: [num_experts]
-    - c_strides2 (torch.Tensor): The output strides for the second gemm.
-        Shape: [num_experts]
-    - per_act_token (Optional[bool]): Whether the scale is per-token or
-        per-tensor.
-    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
-            ab_strides1=ab_strides1,
-            ab_strides2=ab_strides2,
-            c_strides1=c_strides1,
-            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
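
For context, a minimal standalone sketch of the per-expert stride setup that this change inlines into run_cutlass_moe_fp8. The helper name is hypothetical; the constant values follow the diff above and the removed docstring entries, assuming contiguous row-major operands with K the hidden size and N the per-expert intermediate size.

import torch


def make_expert_strides(num_experts: int, N: int, K: int,
                        device: torch.device) -> tuple[torch.Tensor, ...]:
    # One int64 stride per expert for each grouped gemm, matching the
    # constants above: K for gemm1 inputs/weights, 2N for gemm1 outputs,
    # N for gemm2 inputs/weights, K for gemm2 outputs.
    ab_strides1 = torch.full((num_experts, ), K, device=device,
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ), 2 * N, device=device,
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ), N, device=device,
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ), K, device=device,
                            dtype=torch.int64)
    return ab_strides1, c_strides1, ab_strides2, c_strides2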