
Commit 16092a5

batched moe test
Signed-off-by: Bill Nell <[email protected]>
1 parent be24517 commit 16092a5

2 files changed: +161 additions, -10 deletions


tests/kernels/moe/test_moe.py

Lines changed: 137 additions & 1 deletion
@@ -14,7 +14,7 @@
 from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
                                  torch_moe_single)
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
@@ -25,6 +25,7 @@
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.model_executor.layers.activation import SiluAndMul

 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -106,6 +107,141 @@ def test_fused_moe(
                                rtol=0)


+def batch_by_experts(
+    a: torch.Tensor,
+    topk_ids: torch.Tensor,
+    num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert topk_ids.dim() == 2
+    assert topk_ids.shape[0] == a.shape[0]
+
+    # Count how many tokens are routed to each expert.
+    tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device)
+    for i in range(topk_ids.shape[0]):
+        for j in range(topk_ids.shape[1]):
+            expert_id = topk_ids[i, j]
+            tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1
+
+    # Pack activations per expert, padded up to the busiest expert's count.
+    max_num_tokens = tokens_per_expert.max()
+    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
+                      dtype=a.dtype, device=a.device)
+
+    # Fill each expert's slots in token order so that unbatch_output can
+    # read them back with matching per-expert counters.
+    expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
+    for i in range(topk_ids.shape[0]):
+        for j in range(topk_ids.shape[1]):
+            expert_id = topk_ids[i, j]
+            idx = expert_counts[expert_id]
+            b_a[expert_id, idx, :] = a[i, :]
+            expert_counts[expert_id] = idx + 1
+
+    return b_a, tokens_per_expert
+
+
+def unbatch_output(b_out, topk_ids, K):
+    num_tokens, topk = topk_ids.shape
+    num_experts = b_out.shape[0]
+
+    # Scatter each expert's rows back to their (token, topk) positions,
+    # consuming every expert's batch in the same token order in which
+    # batch_by_experts filled it.
+    out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device)
+    expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
+    for token in range(num_tokens):
+        expert_ids = topk_ids[token]
+        for i in range(expert_ids.numel()):
+            expert_id = expert_ids[i]
+            idx = expert_counts[expert_id]
+            out[token, i, :] = b_out[expert_id, idx, :]
+            expert_counts[expert_id] = idx + 1
+
+    return out
+
+
+def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
+    assert a.dim() == 3
+    num_tokens, topk = topk_ids.shape
+    _, max_num_tokens, K = a.shape
+    num_experts = w1.shape[0]
+    out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]),
+                      dtype=a.dtype, device=a.device)
+    # Apply each expert's MLP to its whole padded batch; the padded rows are
+    # zero and are dropped again when the output is unbatched below.
+    for expert in range(num_experts):
+        num = tokens_per_expert[expert]
+        if num > 0:
+            out[expert, :, :] = SiluAndMul()(
+                a[expert, :, :] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
+
+    out = unbatch_output(out, topk_ids, w2.shape[1])
+
+    return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+@pytest.mark.parametrize("m", [1, 33, 64, 222])  # , 1024 * 128
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+
+        b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
+
+        if True:
+            triton_output = torch_batched_moe(b_a,
+                                              w1,
+                                              w2,
+                                              tokens_per_expert,
+                                              topk_weight,
+                                              topk_ids)
+        else:
+            triton_output = fused_experts(a,  # b_a
+                                          w1,
+                                          w2,
+                                          topk_weight,
+                                          topk_ids,
+                                          global_num_experts=e)
+
+        torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+
+
 @pytest.mark.parametrize("m", [1, 32, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 1024])

vllm/distributed/parallel_state.py

Lines changed: 24 additions & 9 deletions
@@ -915,16 +915,31 @@ def init_distributed_environment(
         "world group already initialized with a different world size")


+PPLX_DID_INIT: bool = False
+
 @run_once
 def pplx_init(rank, world_size):
-    print(f"PPLX_INIT {rank} {world_size}")
-    uid = nvshmem_get_unique_id(
-    ) if rank == 0 else nvshmem_alloc_empty_unique_id()
-    uid_gpu = uid.cuda()
-    get_world_group().broadcast(uid_gpu, src=0)
-    print(f"PPLX_INIT UID={uid_gpu}")
-    uid = uid_gpu.to(device='cpu')
-    nvshmem_init(uid, rank, world_size)
+    if world_size > 1:
+        try:
+            global PPLX_DID_INIT
+            print(f"PPLX_INIT {rank} {world_size}")
+            uid = nvshmem_get_unique_id(
+            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
+            uid_gpu = uid.cuda()
+            get_world_group().broadcast(uid_gpu, src=0)
+            print(f"PPLX_INIT UID={uid_gpu}")
+            uid = uid_gpu.to(device='cpu')
+            nvshmem_init(uid, rank, world_size)
+            PPLX_DID_INIT = True
+        except Exception as ex:
+            logger.error("Failed to initialize nvshmem for pplx: %s", ex)
+
+
+@run_once
+def pplx_finalize():
+    global PPLX_DID_INIT
+    if PPLX_DID_INIT:
+        nvshmem_finalize()


 def initialize_model_parallel(
@@ -1099,7 +1114,7 @@ def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP

-    nvshmem_finalize()
+    pplx_finalize()

     if _TP:
         _TP.destroy()
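The parallel_state change turns the nvshmem bootstrap into a guarded, run-once pair: pplx_init only attempts initialization when world_size > 1, records success in the module-level PPLX_DID_INIT flag, and pplx_finalize calls nvshmem_finalize only if that flag was set, so destroy_model_parallel no longer tears down state that was never created. Below is a minimal, self-contained sketch of that pattern; the run_once shown here is a simplified stand-in for the decorator the diff already uses, the names example_init, example_finalize, and _DID_INIT are hypothetical, and the nvshmem calls are elided.

import logging
import threading

logger = logging.getLogger(__name__)

_DID_INIT: bool = False


def run_once(fn):
    # Illustrative stand-in: the wrapped function runs on the first call only.
    lock = threading.Lock()
    called = False

    def wrapper(*args, **kwargs):
        nonlocal called
        with lock:
            if called:
                return None
            called = True
        return fn(*args, **kwargs)

    return wrapper


@run_once
def example_init(rank: int, world_size: int) -> None:
    global _DID_INIT
    if world_size <= 1:
        return  # single-rank runs need no shared nvshmem state
    try:
        # ... broadcast a unique id and call nvshmem_init(uid, rank, world_size)
        _DID_INIT = True
    except Exception as ex:
        # A failed bootstrap is logged; _DID_INIT stays False, so the
        # matching finalize below remains a no-op.
        logger.error("Failed to initialize nvshmem: %s", ex)


@run_once
def example_finalize() -> None:
    if _DID_INIT:
        pass  # ... call nvshmem_finalize()

Making both ends run-once means repeated destroy_model_parallel calls, or an init that failed partway, cannot double-finalize the shared nvshmem state.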
