Commit f6acee6

test pplx w/naive implementation
Signed-off-by: Bill Nell <[email protected]>
Parent: 0dfd27e

4 files changed: 99 additions & 26 deletions


tests/kernels/moe/test_moe.py

Lines changed: 9 additions & 12 deletions
@@ -118,11 +118,7 @@ def batch_by_experts(
     num_tokens = a.shape[0]
     topk = topk_ids.shape[1]
 
-    tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device)
-    for i in range(num_tokens):
-        for j in range(topk):
-            expert_id = topk_ids[i, j]
-            tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1
+    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
 
     max_num_tokens = tokens_per_expert.max()
     b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
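
Note: the removed counting loop and the added torch.bincount call produce the same per-expert token histogram. A minimal standalone check of that equivalence (the sizes below are made up for illustration and are not taken from the test):

import torch

num_tokens, topk, num_experts = 16, 2, 8
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))

# Naive counting loop, as removed above.
counts_loop = torch.zeros(num_experts, dtype=torch.int)
for i in range(num_tokens):
    for j in range(topk):
        counts_loop[topk_ids[i, j]] += 1

# Vectorized version, as added above.
counts_vec = torch.bincount(topk_ids.view(-1), minlength=num_experts)

assert torch.equal(counts_loop, counts_vec.to(counts_loop.dtype))
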
@@ -170,7 +166,6 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
         num = tokens_per_expert[expert]
         if num > 0:
             out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
-            #out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
 
     out = unbatch_output(out, topk_weight, topk_ids, K)

@@ -231,12 +226,14 @@ def test_fused_moe_batched_experts(
                                             topk_weight,
                                             topk_ids)
     else:
-        triton_output = fused_experts(b_a,
-                                      w1,
-                                      w2,
-                                      topk_weight,
-                                      topk_ids,
-                                      global_num_experts=e)
+        triton_output = fused_batched_experts(
+            b_a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            global_num_experts=e
+        )
 
     if False:
         torch.set_printoptions(profile="full")

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 66 additions & 0 deletions
@@ -1754,6 +1754,72 @@ def apply(
         return intermediate_cache3
 
 
+class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        block_shape: Optional[List[int]] = None,
+        block_m: Optional[int] = None,
+    ):
+        super().__init__()
+        assert not use_fp8_w8a8
+        assert not use_int4_w4a16
+        assert not use_int8_w8a16
+        assert block_shape is None
+        assert block_m is None
+
+    def workspace_shapes(
+        self,
+        a_dtype: torch.dtype,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+        a: torch.Tensor,
+    ) -> Tuple[int, int, torch.dtype]:
+        max_num_tokens = a.shape[1]
+        workspace13 = num_experts * max_num_tokens * K
+        workspace2 = M * topk * N * num_experts
+        return (workspace13, workspace2, a_dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.activation import SiluAndMul
+        assert hidden_states.dim() == 3
+        num_tokens, topk = topk_ids.shape
+        _, max_num_tokens, K = hidden_states.shape
+        num_experts = w1.shape[0]
+        out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1]))
+        #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
+        for expert in range(num_experts):
+            num = 1 #tokens_per_expert[expert]
+            if num > 0:
+                #out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
+                out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
+
+        return out
+
+
 def modular_triton_fused_moe(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
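
Note: BatchedExperts.apply above is a deliberately naive reference: the tokens_per_expert count stays commented out, so every expert runs over its full padded max_num_tokens slot. The per-expert math can be sketched standalone, outside the modular-kernel plumbing; the function name below is made up, silu(gate) * up stands in for SiluAndMul, and the usual w1: [E, 2N, K] / w2: [E, K, N] weight layout is assumed:

import torch


def naive_batched_experts(hidden_states: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor) -> torch.Tensor:
    # hidden_states: [num_experts, max_num_tokens, K], already grouped per expert.
    num_experts, max_num_tokens, K = hidden_states.shape
    out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]),
                      dtype=hidden_states.dtype)
    for expert in range(num_experts):
        gate_up = hidden_states[expert] @ w1[expert].transpose(0, 1)  # [T, 2N]
        gate, up = gate_up.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up                     # SiluAndMul
        out[expert] = act @ w2[expert].transpose(0, 1)                # [T, K]
    return out

Restoring the commented-out bincount and slicing each expert with :num would skip the padded rows, matching torch_batched_moe in the test above.
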

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 9 deletions
@@ -29,7 +29,7 @@
 
 if current_platform.is_cuda_alike():
     from .dispatch_combine import StandardDispatchCombine
-    from .fused_moe import TritonExperts, fused_experts
+    from .fused_moe import TritonExperts, BatchedExperts, fused_experts
     from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine
     from .pplx_dispatch_combine import PplxDispatchCombine
 else:
@@ -243,13 +243,16 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine
         block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size)
         #print(f"block_m = {block_m}")
 
-        experts = TritonExperts(
-            use_fp8_w8a8 = False,
-            use_int8_w8a16 = False,
-            use_int4_w4a16 = False,
-            block_shape = None,
-            block_m = None, #block_m,
-        )
+        if False:
+            experts = TritonExperts(
+                use_fp8_w8a8 = False,
+                use_int8_w8a16 = False,
+                use_int4_w4a16 = False,
+                block_shape = None,
+                block_m = None, #block_m,
+            )
+        else:
+            experts = BatchedExperts()
 
         self.fused_experts = FusedMoEModularKernel(
             dispatch_combine,
@@ -1029,7 +1032,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
             router_logits = full_router_logits[chunk_start:chunk_end, :]
 
-            print(f"loop {chunk_start}:{chunk_end}")
+            #print(f"loop {chunk_start}:{chunk_end}")
 
             cu_tokens_across_dp_this_iter = torch.cumsum(
                 num_tokens_remaining_across_dp.clamp(

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 12 additions & 5 deletions
@@ -35,16 +35,23 @@ def __init__(
         self.allow_deep_gemm = allow_deep_gemm
         self.use_fp8_w8a8 = use_fp8_w8a8
 
-    def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int,
-                         topk: int,
-                         num_experts: int) -> Tuple[int, int, torch.dtype]:
+    def workspace_shapes(
+        self,
+        a_dtype: torch.dtype,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+        a: torch.Tensor,
+    ) -> Tuple[int, int, torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
-            return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts)
+            return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a)
         else:
-            return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts)
+            return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a)
 
     def apply(
         self,
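
Note: the only functional change in this file is threading the activation tensor a through workspace_shapes so that batched implementations can size their buffers from the batched shape rather than from M alone. With the BatchedExperts.workspace_shapes added in fused_moe.py above, the sizing works out as in this small illustration (all concrete numbers are made up):

import torch

num_experts, max_num_tokens, K = 8, 64, 1024
M, topk, N = 16, 2, 2048

# Batched activation as seen by BatchedExperts: one padded slot per expert.
a = torch.empty((num_experts, max_num_tokens, K))

# Mirrors BatchedExperts.workspace_shapes in this commit.
workspace13 = num_experts * a.shape[1] * K  # per-expert padded output buffer
workspace2 = M * topk * N * num_experts     # intermediate buffer
print(workspace13, workspace2)
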
