Commit 3433b73

varun's fixes
Signed-off-by: Bill Nell <[email protected]>
1 parent 9018df8 commit 3433b73

File tree

6 files changed: +824 -39 lines changed


tests/kernels/moe/test_batched_moe.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

import pytest
from dataclasses import dataclass

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel,
    invoke_batched_silu_and_mul)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 50.0
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype) / 50.0
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        # Number of valid rows per expert; rows beyond this are padding.
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        # Only the first num_tokens rows of each expert hold real data.
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [512])
@pytest.mark.parametrize("K", [256])
@pytest.mark.parametrize("N", [512])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)


@dataclass
class BatchedSiluMulConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    D: int


@dataclass
class BatchedSiluMulTensors:
    input: torch.Tensor  # [E, max_tokens, 2 * D]
    output: torch.Tensor  # [E, max_tokens, D]
    expert_num_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedSiluMulConfig):
        input = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
            device="cuda",
            dtype=config.dtype) / 50.0
        output = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.D),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedSiluMulTensors(input, output, num_expert_tokens)


def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
                         num_expert_tokens: torch.Tensor) -> None:
    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e].item()
        out_part = output[e, :num_tokens, :]
        in_part = input[e, :num_tokens, :]
        torch.ops._C.silu_and_mul(out_part, in_part)


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [128])
@pytest.mark.parametrize("D", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int,
                          D: int, dtype: torch.dtype):

    config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert,
                                  D)
    tensors = BatchedSiluMulTensors.make_tensors(config)

    test_out = tensors.output
    ref_out = torch.zeros_like(test_out)

    ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)

    invoke_batched_silu_and_mul(test_out, tensors.input,
                                tensors.expert_num_tokens)

    torch.testing.assert_close(test_out, ref_out)
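Note: the reference above exercises the compiled torch.ops._C.silu_and_mul custom op. As a minimal pure-PyTorch sketch of the same check, assuming the usual SiLU-and-multiply definition (the first half of the last dimension is passed through SiLU and gates the second half), a hypothetical helper like the one below could stand in for environments without the compiled ops:

import torch
import torch.nn.functional as F


def ref_silu_mul_torch(output: torch.Tensor, input: torch.Tensor,
                       num_expert_tokens: torch.Tensor) -> None:
    # Assumed semantics: out = silu(x[..., :D]) * x[..., D:].
    D = output.size(-1)
    for e in range(num_expert_tokens.size(0)):
        n = int(num_expert_tokens[e].item())
        x = input[e, :n, :]
        output[e, :n, :] = F.silu(x[..., :D]) * x[..., D:]

As in the test, only the first num_expert_tokens[e] rows of each expert are written; padded rows stay zero, which is why ref_out is zero-initialized before the comparison.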

vllm/distributed/parallel_state.py

Lines changed: 2 additions & 2 deletions
@@ -923,12 +923,12 @@ def pplx_init(rank, world_size):
     if world_size > 1:
         try:
             global PPLX_DID_INIT
-            print(f"PPLX_INIT {rank} {world_size}")
+            logger.debug(f"PPLX_INIT {rank} {world_size}")
             uid = nvshmem_get_unique_id(
             ) if rank == 0 else nvshmem_alloc_empty_unique_id()
             uid_gpu = uid.cuda()
             get_world_group().broadcast(uid_gpu, src=0)
-            print(f"PPLX_INIT UID={uid_gpu}")
+            logger.debug(f"PPLX_INIT UID={uid_gpu}")
             uid = uid_gpu.to(device='cpu')
             nvshmem_init(uid, rank, world_size)
             PPLX_DID_INIT = True
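A side note on the logger.debug calls in this hunk: Python's logging module also accepts deferred %-style arguments, in which case the message is only formatted if a handler actually emits the record. A small illustrative sketch (the helper name is hypothetical, not part of this change):

import logging

logger = logging.getLogger(__name__)


def log_pplx_init(rank: int, world_size: int) -> None:
    # Arguments are interpolated lazily, only when the record is emitted.
    logger.debug("PPLX_INIT %d %d", rank, world_size)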

0 commit comments