 from vllm.scalar_type import scalar_types
 
 
-FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
-FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = int(
-    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
-
-
-def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
-                    device: torch.device):
-    """
-    MoE implementation for FP4 inputs.
-
-    # Gemm 1
-    a: Input tensor: [m, k] (half/bfloat16)
-    a1_gscale: Activation scale per expert: [e] (float32)
-    w1 (gate up, not an argument to cutlass_moe_fp4): [e, 2 * n, k]
-    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (packed fp4: E2M1)
-        (Note: `n` is the up-projection output dim, `k` is the input dim in
-        full precision)
-    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
-        (block_size = 16 for NVFP4)
-
-    # Gemm 2
-    a2_gscale: Activation scale per expert: [e]
-    w2 (down projection, not an argument to cutlass_moe_fp4): [e, k, n]
-    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (packed E2M1)
-    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
-
-    topk_weights: [m, topk], dtype: float32
-    topk_ids: [m, topk], dtype: int32
-
-    m, n, k: unquantized weight shapes, dtype: int
-    e: number of experts, dtype: int
-
-    Assumes that topk < k < n, satisfying the up/down projection expectations.
-    """
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
-    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
-    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
-            and w2_blockscale.ndim == 3), (
-                "All weights must be of rank 3 for cutlass_moe_fp4")
-    m_a, k_a = a.shape
-    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
-    e_w2, k_w2, half_n_w2 = w2_fp4.shape
-
-    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match"
-                                          " between weights.")
-    assert (k_a // 2 == half_k_w1
-            and k == k_w2), "Hidden size mismatch between a, w1 and w2"
-    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), "mismatch in expected `n`"
-    assert (m == m_a), "input shape mismatch"
-    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch between w2 and w1"
-    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
-    assert (topk_weights.shape[0] == m and
-            topk_ids.shape[0] == m), "topk must be provided for each row of a"
-    assert (m <= MAX_TOKENS_PER_EXPERT), (
-        f"m must not exceed MAX_TOKENS_PER_EXPERT ({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}. Use"
-        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to raise this limit.")
-    out_dtype = a.dtype
-    num_topk = topk_ids.shape[1]
-
-    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
-    # Problem sizes for GEMM 1: (num_experts, (m, 2n, k))
-    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
-    # Problem sizes for GEMM 2: (num_experts, (m, n, k))
-    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
-
-    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
-
-    # Problem shapes should have [m, n, k].
-    # Note that problem sizes are based on the logical number of elements.
-    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
-                                problem_sizes2, a_map, c_map, e, n, k)
-
-    tokens_per_expert = problem_sizes1[:, 0]
-    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
-    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
-    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
-
-    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
-        a,
-        a1_gscale,
-        expert_offsets,
-        blockscale_offsets,
-        num_topk,
-        expert_map=a_map,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
-
-    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
-                                w1_blockscale, w1_alphas, problem_sizes1,
-                                expert_offsets[:-1], blockscale_offsets[:-1],
-                                out_dtype, device)
-    del rep_a_fp4, rep_a_blockscale
-    # The hidden size dimension is split into one half-sized tensor.
-    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
-                               device=device,
-                               dtype=out_dtype)
-
-    torch.ops._C.silu_and_mul(intermediate, c1)
-
-    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
-        intermediate,
-        a2_gscale,
-        expert_offsets,
-        blockscale_offsets,
-        num_topk,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
-
-    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
-                                w2_alphas, problem_sizes2, expert_offsets[:-1],
-                                blockscale_offsets[:-1], out_dtype, device)
-    del int_fp4, int_blockscale
-    out = (c2[c_map].view(m, num_topk, k) *
-           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
-    return out.to(dtype=out_dtype)
-
-
-class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
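A note on the NVFP4 layout documented in the moved docstring above: each `torch.uint8` element of `w1_fp4`/`w2_fp4` carries two E2M1 values (1 sign, 2 exponent, 1 mantissa bit), which is why the last weight dimension is halved (`k // 2`, `n // 2`). Below is a toy sketch of that nibble packing; the low-nibble-first ordering is an assumption for illustration, and the real quantization happens on-device in `ops.scaled_fp4_experts_quant`:

```python
import torch


def pack_fp4_pairs(nibbles: torch.Tensor) -> torch.Tensor:
    """Pack an even-length uint8 tensor of 4-bit codes into bytes."""
    assert nibbles.shape[-1] % 2 == 0
    lo = nibbles[..., 0::2] & 0xF  # even positions -> low nibble (assumed)
    hi = nibbles[..., 1::2] & 0xF  # odd positions  -> high nibble
    return (hi << 4) | lo


def unpack_fp4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_fp4_pairs: one byte back into two 4-bit codes."""
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=-1).flatten(-2)


# Two E2M1 codes per byte halve the last dim: [e, 2n, k] -> [e, 2n, k // 2].
codes = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
packed = pack_fp4_pairs(codes)  # shape (4, 4)
assert torch.equal(unpack_fp4_pairs(packed), codes)
```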
@@ -298,7 +173,7 @@ def apply(
                            expert_offsets[:-1], problem_sizes2,
                            self.ab_strides2, self.ab_strides2, self.c_strides2)
 
-        c3 = c3[c_map, ...]
+        c3 = c3[c_map]
 
         return c3
 
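The indexing change in this hunk is purely cosmetic: a trailing `Ellipsis` after a single row index is a no-op, so both forms gather the same rows.

```python
import torch

x = torch.arange(12).view(4, 3)
idx = torch.tensor([2, 0, 3, 1])
# A trailing `...` after the row index adds nothing: both gathers agree.
assert torch.equal(x[idx, ...], x[idx])
```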
@@ -316,7 +191,7 @@ def modular_cutlass_moe_fp8(
             per_channel_quant=per_act_token,
             quant_dtype=torch.float8_e4m3fn,
         ),
-        CutlassExperts(
+        CutlassExpertsFp8(
             ab_strides1,
             c_strides1,
             ab_strides2,
@@ -413,3 +288,128 @@ def cutlass_moe_fp8(
         a2_scale=a2_scale,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
+
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+MAX_TOKENS_PER_EXPERT = int(
+    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
+
+
+def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
+                    device: torch.device):
+    """
+    MoE implementation for FP4 inputs.
+
+    # Gemm 1
+    a: Input tensor: [m, k] (half/bfloat16)
+    a1_gscale: Activation scale per expert: [e] (float32)
+    w1 (gate up, not an argument to cutlass_moe_fp4): [e, 2 * n, k]
+    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (packed fp4: E2M1)
+        (Note: `n` is the up-projection output dim, `k` is the input dim in
+        full precision)
+    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
+        (block_size = 16 for NVFP4)
+
+    # Gemm 2
+    a2_gscale: Activation scale per expert: [e]
+    w2 (down projection, not an argument to cutlass_moe_fp4): [e, k, n]
+    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (packed E2M1)
+    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
+
+    topk_weights: [m, topk], dtype: float32
+    topk_ids: [m, topk], dtype: int32
+
+    m, n, k: unquantized weight shapes, dtype: int
+    e: number of experts, dtype: int
+
+    Assumes that topk < k < n, satisfying the up/down projection expectations.
+    """
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
+    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
+    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
+            and w2_blockscale.ndim == 3), (
+                "All weights must be of rank 3 for cutlass_moe_fp4")
+    m_a, k_a = a.shape
+    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
+    e_w2, k_w2, half_n_w2 = w2_fp4.shape
+
+    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match"
+                                          " between weights.")
+    assert (k_a // 2 == half_k_w1
+            and k == k_w2), "Hidden size mismatch between a, w1 and w2"
+    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), "mismatch in expected `n`"
+    assert (m == m_a), "input shape mismatch"
+    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch between w2 and w1"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
+    assert (topk_weights.shape[0] == m and
+            topk_ids.shape[0] == m), "topk must be provided for each row of a"
+    assert (m <= MAX_TOKENS_PER_EXPERT), (
+        f"m must not exceed MAX_TOKENS_PER_EXPERT ({MAX_TOKENS_PER_EXPERT})"
+        f" for cutlass_moe_fp4, observed m = {m}. Use"
+        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to raise this limit.")
+    out_dtype = a.dtype
+    num_topk = topk_ids.shape[1]
+
+    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
+    # Problem sizes for GEMM 1: (num_experts, (m, 2n, k))
+    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
+    # Problem sizes for GEMM 2: (num_experts, (m, n, k))
+    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    # Problem shapes should have [m, n, k].
+    # Note that problem sizes are based on the logical number of elements.
+    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, a_map, c_map, e, n, k)
+
+    tokens_per_expert = problem_sizes1[:, 0]
+    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
+    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
+    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
+
+    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
+        a,
+        a1_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        expert_map=a_map,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
+                                w1_blockscale, w1_alphas, problem_sizes1,
+                                expert_offsets[:-1], blockscale_offsets[:-1],
+                                out_dtype, device)
+    del rep_a_fp4, rep_a_blockscale
+    # The hidden size dimension is split into one half-sized tensor.
+    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
+                               device=device,
+                               dtype=out_dtype)
+
+    torch.ops._C.silu_and_mul(intermediate, c1)
+
+    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
+        intermediate,
+        a2_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
+                                w2_alphas, problem_sizes2, expert_offsets[:-1],
+                                blockscale_offsets[:-1], out_dtype, device)
+    del int_fp4, int_blockscale
+    out = (c2[c_map].view(m, num_topk, k) *
+           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    return out.to(dtype=out_dtype)
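The `a_map`/`c_map`/`expert_offsets` buffers filled by `ops.get_cutlass_moe_mm_data` group the replicated token rows by expert so each expert's GEMM sees a contiguous slab. A CPU emulation of that bookkeeping (my reconstruction of the semantics, not the actual kernel):

```python
import torch

# Hypothetical routing for m = 2 tokens, topk = 2, e = 3 experts.
topk_ids = torch.tensor([[0, 2], [2, 1]])
flat = topk_ids.flatten()                  # [m * topk] replicated rows
a_map = torch.argsort(flat, stable=True)   # gather: token-major -> expert-major
c_map = torch.empty_like(a_map)
c_map[a_map] = torch.arange(flat.numel())  # inverse permutation of a_map
counts = torch.bincount(flat, minlength=3)  # rows routed to each expert
expert_offsets = torch.zeros(4, dtype=torch.int64)
expert_offsets[1:] = torch.cumsum(counts, dim=0)  # -> [0, 1, 2, 4]
```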
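The block-scale bookkeeping in `cutlass_moe_fp4` rounds each expert's token count up to a multiple of 128 before taking the prefix sum, since the scale layout is padded per expert. The same arithmetic on made-up CPU values:

```python
import torch

# Suppose e = 3 experts received 5, 130 and 0 routed tokens.
tokens_per_expert = torch.tensor([5, 130, 0], dtype=torch.int32)
# Round each count up to a multiple of 128, as the scale layout requires.
rounded = (tokens_per_expert + 127) // 128 * 128  # -> [128, 256, 0]
offsets = torch.zeros(4, dtype=torch.int32)
offsets[1:] = torch.cumsum(rounded, dim=0)        # -> [0, 128, 384, 384]
```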
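Between the two GEMMs, `torch.ops._C.silu_and_mul` fuses the activation epilogue into the half-width `intermediate` buffer. A pure-PyTorch reference of that epilogue, assuming (as in vLLM's kernel) that the gate half comes first along the last dimension:

```python
import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Reference epilogue: SiLU on the first half, gated by the second."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


c1 = torch.randn(6, 8)               # stand-in for GEMM-1 output [tokens, 2n]
intermediate = silu_and_mul_ref(c1)  # [tokens, n], fed to the second quant
```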
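Finally, the last lines of `cutlass_moe_fp4` un-permute the expert-major GEMM-2 output with `c_map` and take the router-weighted sum over each token's top-k experts. The same combine step on toy data (the `c_map` values here are made up for illustration):

```python
import torch

m, num_topk, k = 2, 2, 3
c2 = torch.randn(m * num_topk, k)        # expert-major GEMM-2 output rows
c_map = torch.tensor([0, 2, 1, 3])       # hypothetical expert->token row order
topk_weights = torch.rand(m, num_topk)   # router weights per (token, expert)

# Gather rows back into token-major order, then reduce over top-k.
out = (c2[c_map].view(m, num_topk, k) *
       topk_weights.view(m, num_topk, 1)).sum(dim=1)
assert out.shape == (m, k)
```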