Commit ca2ff26

fix merge + add comments

Signed-off-by: Bill Nell <[email protected]>
1 parent 43e229c commit ca2ff26
File tree

2 files changed: +143 -131 lines changed
vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 128 additions & 128 deletions
@@ -16,132 +16,7 @@
 from vllm.scalar_type import scalar_types
 
 
-FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
-FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = int(
-    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
-
-
-def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
-                    device: torch.device):
-    """
-    MoE implementation for FP4 Inputs
-
-    # Gemm 1
-    a: Input tensor: [m, k] (half/bfloat16)
-    a1_gscale: Activation scale per expert: [e] (float32)
-    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
-    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
-    (Note: `n` is the up projection output dim, `k` is the input dim in
-     full precision)
-    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
-                   (Block size = 16 for NVFP4)
-
-    # Gemm 2
-    a2_gscale: Activation scale per expert: [e]
-    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
-    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
-    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
-
-    topk_weights: [m, topk] dtype: float8
-    topk_ids: [m, topk] dtype: float8
-
-    m, n, k: Unquantized weight shapes, dtype: int
-    e: number of experts, dtype: int
-
-    assumes that topk < k < n to satisfy - up/down projection expectations.
-    """
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
-    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
-    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
-            and w2_blockscale.ndim
-            == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
-    m_a, k_a = a.shape
-    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
-    e_w2, k_w2, half_n_w2 = w2_fp4.shape
-
-    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
-                                          " between weights.")
-    assert (k_a // 2 == half_k_w1
-            and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
-    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
-                                                       "expected `n`")
-    assert (m == m_a), "input shape mismatch"
-    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
-    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
-    assert (topk_weights.shape[0] == m and topk_ids.shape[0]
-            == m), ("topk must be provided for each row of a")
-    assert (m <= MAX_TOKENS_PER_EXPERT), (
-        f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}. Use"
-        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
-    out_dtype = a.dtype
-    num_topk = topk_ids.shape[1]
-
-    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
-    # Problem size: (num_experts, (m,2n,k))
-    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
-    # Problem size: (num_experts, (m,n,k))
-    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
-
-    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
-
-    # problem shapes should have [m, n, k]
-    # Note that problem sizes are based on logical number of elements.
-    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
-                                problem_sizes2, a_map, c_map, e, n, k)
-
-    tokens_per_expert = problem_sizes1[:, 0]
-    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
-    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
-    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
-
-    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
-        a,
-        a1_gscale,
-        expert_offsets,
-        blockscale_offsets,
-        num_topk,
-        expert_map=a_map,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
-
-    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
-                                w1_blockscale, w1_alphas, problem_sizes1,
-                                expert_offsets[:-1], blockscale_offsets[:-1],
-                                out_dtype, device)
-    del rep_a_fp4, rep_a_blockscale
-    # hidden size dimension is split to one half sized tensor.
-    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
-                               device=device,
-                               dtype=out_dtype)
-
-    torch.ops._C.silu_and_mul(intermediate, c1)
-
-    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
-        intermediate,
-        a2_gscale,
-        expert_offsets,
-        blockscale_offsets,
-        num_topk,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
-
-    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
-                                w2_alphas, problem_sizes2, expert_offsets[:-1],
-                                blockscale_offsets[:-1], out_dtype, device)
-    del int_fp4, int_blockscale
-    out = (c2[c_map].view(m, num_topk, k) *
-           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
-    return out.to(dtype=out_dtype)
-
-
-class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
@@ -298,7 +173,7 @@ def apply(
             expert_offsets[:-1], problem_sizes2,
             self.ab_strides2, self.ab_strides2, self.c_strides2)
 
-        c3 = c3[c_map, ...]
+        c3 = c3[c_map]
 
         return c3
 
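The simplification above leans on a standard PyTorch indexing rule: for a 1-D index tensor applied to the first dimension, a trailing Ellipsis changes nothing. A minimal sketch (tensor names and sizes are illustrative, not from this commit):

    import torch

    x = torch.arange(12.).reshape(4, 3)
    idx = torch.tensor([2, 0, 3, 1])
    # x[idx, ...] and x[idx] gather the same rows along dim 0,
    # so dropping the Ellipsis is purely cosmetic.
    assert torch.equal(x[idx, ...], x[idx])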
@@ -316,7 +191,7 @@ def modular_cutlass_moe_fp8(
             per_channel_quant=per_act_token,
             quant_dtype=torch.float8_e4m3fn,
         ),
-        CutlassExperts(
+        CutlassExpertsFp8(
             ab_strides1,
             c_strides1,
             ab_strides2,
@@ -413,3 +288,128 @@ def cutlass_moe_fp8(
         a2_scale=a2_scale,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
+
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+MAX_TOKENS_PER_EXPERT = int(
+    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
+
+
+def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
+                    device: torch.device):
+    """
+    MoE implementation for FP4 Inputs
+
+    # Gemm 1
+    a: Input tensor: [m, k] (half/bfloat16)
+    a1_gscale: Activation scale per expert: [e] (float32)
+    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
+    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
+    (Note: `n` is the up projection output dim, `k` is the input dim in
+     full precision)
+    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
+                   (Block size = 16 for NVFP4)
+
+    # Gemm 2
+    a2_gscale: Activation scale per expert: [e]
+    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
+    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
+    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
+
+    topk_weights: [m, topk] dtype: float8
+    topk_ids: [m, topk] dtype: float8
+
+    m, n, k: Unquantized weight shapes, dtype: int
+    e: number of experts, dtype: int
+
+    assumes that topk < k < n to satisfy - up/down projection expectations.
+    """
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
+    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
+    assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
+            and w2_blockscale.ndim
+            == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
+    m_a, k_a = a.shape
+    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
+    e_w2, k_w2, half_n_w2 = w2_fp4.shape
+
+    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
+                                          " between weights.")
+    assert (k_a // 2 == half_k_w1
+            and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
+    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
+                                                       "expected `n`")
+    assert (m == m_a), "input shape mismatch"
+    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
+    assert (topk_weights.shape[0] == m and topk_ids.shape[0]
+            == m), ("topk must be provided for each row of a")
+    assert (m <= MAX_TOKENS_PER_EXPERT), (
+        f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
+        f" for cutlass_moe_fp4, observed m = {m}. Use"
+        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
+    out_dtype = a.dtype
+    num_topk = topk_ids.shape[1]
+
+    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
+    # Problem size: (num_experts, (m,2n,k))
+    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
+    # Problem size: (num_experts, (m,n,k))
+    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    # problem shapes should have [m, n, k]
+    # Note that problem sizes are based on logical number of elements.
+    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, a_map, c_map, e, n, k)
+
+    tokens_per_expert = problem_sizes1[:, 0]
+    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
+    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
+    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
+
+    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
+        a,
+        a1_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        expert_map=a_map,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
+                                w1_blockscale, w1_alphas, problem_sizes1,
+                                expert_offsets[:-1], blockscale_offsets[:-1],
+                                out_dtype, device)
+    del rep_a_fp4, rep_a_blockscale
+    # hidden size dimension is split to one half sized tensor.
+    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
+                               device=device,
+                               dtype=out_dtype)
+
+    torch.ops._C.silu_and_mul(intermediate, c1)
+
+    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
+        intermediate,
+        a2_gscale,
+        expert_offsets,
+        blockscale_offsets,
+        num_topk,
+        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+
+    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
+                                w2_alphas, problem_sizes2, expert_offsets[:-1],
+                                blockscale_offsets[:-1], out_dtype, device)
+    del int_fp4, int_blockscale
+    out = (c2[c_map].view(m, num_topk, k) *
+           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    return out.to(dtype=out_dtype)
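The docstring of the relocated cutlass_moe_fp4 fixes a fairly strict shape contract between the activations, the packed FP4 weights, and the blockscales. A minimal sketch of shapes that satisfy its asserts (all sizes here are illustrative assumptions; this builds CPU tensors only and does not invoke the CUTLASS kernels):

    import torch

    m, n, k, e, topk, block_size = 4, 256, 512, 8, 2, 16

    a = torch.randn(m, k, dtype=torch.bfloat16)                # [m, k]
    w1_fp4 = torch.empty(e, 2 * n, k // 2, dtype=torch.uint8)  # packed E2M1
    w2_fp4 = torch.empty(e, k, n // 2, dtype=torch.uint8)
    w1_blockscale = torch.empty(e, 2 * n, k // block_size,
                                dtype=torch.float8_e4m3fn)
    w2_blockscale = torch.empty(e, k, n // block_size,
                                dtype=torch.float8_e4m3fn)
    topk_ids = torch.randint(0, e, (m, topk), dtype=torch.int32)
    topk_weights = torch.rand(m, topk)

    # The same relationships cutlass_moe_fp4 asserts:
    assert a.shape == (m, k) and w1_fp4.shape == (e, 2 * n, k // 2)
    assert w2_fp4.shape == (e, k, n // 2)
    assert topk_ids.shape == topk_weights.shape == (m, topk)

The final combine step is also worth noting: c2 holds one row per (token, expert) pair, so c2[c_map].view(m, num_topk, k) regroups those rows per token before the weighted sum over the top-k experts.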

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,11 @@ def rank_chunk(num, r, w):
 
 
 class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
-
+    """
+    A reference dispatch/combine class that reorganizes the tokens into
+    expert batched format, i.e. E x max_num_tokens x K. This is the format
+    that the pplx dispatch/combine kernels use.
+    """
     def __init__(self, max_num_tokens: Optional[int], world_size: int,
                  dp_size: int, rank: int):
         super().__init__()
@@ -478,7 +482,11 @@ def combine(
 
 
 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
-
+    """
+    A reference MoE expert class that operates on expert batched format,
+    i.e. E x max_num_tokens x K. This is the format that the pplx
+    dispatch/combine kernels use.
+    """
     def __init__(
         self,
         world_size: int,
@@ -580,7 +588,11 @@ def apply(
 
 
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
-
+    """
+    A Triton based MoE expert class that operates on expert batched format,
+    i.e. E x max_num_tokens x K. This is the format that the pplx
+    dispatch/combine kernels use.
+    """
     def __init__(
         self,
         max_num_tokens: Optional[int] = None,
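The three new docstrings all describe the same "expert batched" activation layout: tokens routed to each expert are gathered into a fixed-capacity slab of shape E x max_num_tokens x K. A rough illustration of that layout (names and sizes are assumptions for the sketch, not code from this commit):

    import torch

    E, max_num_tokens, K = 4, 8, 16          # experts, per-expert capacity, hidden size
    tokens = torch.randn(10, K)              # 10 incoming tokens
    expert_ids = torch.randint(0, E, (10,))  # one expert per token, for simplicity

    batched = torch.zeros(E, max_num_tokens, K)
    counts = torch.zeros(E, dtype=torch.long)
    for tok, eid in zip(tokens, expert_ids.tolist()):
        if counts[eid] < max_num_tokens:     # a real kernel sizes capacity so nothing is dropped
            batched[eid, counts[eid]] = tok
            counts[eid] += 1

Each expert then runs its GEMMs on its own [max_num_tokens, K] slice, and the combine step scatters results back to the original token order.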
