Skip to content

Commit 089a71d

Browse files
committed
merge
Signed-off-by: Bill Nell <[email protected]>
1 parent 27bee28 commit 089a71d

File tree

8 files changed

+84
-346
lines changed

8 files changed

+84
-346
lines changed

tests/kernels/moe/test_moe.py

Lines changed: 2 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -110,143 +110,6 @@ def test_fused_moe(
110110
rtol=0)
111111

112112

113-
def torch_dispatch(
114-
a: torch.Tensor,
115-
topk_ids: torch.Tensor,
116-
num_experts: int
117-
) -> torch.Tensor:
118-
assert topk_ids.dim() == 2
119-
assert topk_ids.shape[0] == a.shape[0]
120-
121-
num_tokens = a.shape[0]
122-
topk = topk_ids.shape[1]
123-
124-
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
125-
126-
max_num_tokens = tokens_per_expert.max()
127-
b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
128-
dtype=a.dtype, device=a.device)
129-
#print(f"b_a shape {b_a.shape}")
130-
131-
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
132-
133-
for token in range(num_tokens):
134-
for j in range(topk):
135-
expert_id = topk_ids[token, j]
136-
idx = token_counts[expert_id]
137-
b_a[expert_id, idx:idx+1, :] = a[token, :]
138-
token_counts[expert_id] = token_counts[expert_id] + 1
139-
140-
return b_a, tokens_per_expert
141-
142-
143-
def torch_combine(b_out, topk_weight, topk_ids):
144-
num_tokens, topk = topk_ids.shape
145-
num_experts = b_out.shape[0]
146-
K = b_out.shape[-1]
147-
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
148-
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
149-
for token in range(num_tokens):
150-
expert_ids = topk_ids[token]
151-
for i in range(expert_ids.numel()):
152-
expert_id = expert_ids[i]
153-
idx = expert_counts[expert_id]
154-
out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i]
155-
expert_counts[expert_id] = expert_counts[expert_id] + 1
156-
157-
return out
158-
159-
160-
def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
161-
num_experts = w1.shape[0]
162-
b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts)
163-
assert b_a.dim() == 3
164-
num_tokens, topk = topk_ids.shape
165-
_, max_num_tokens, K = b_a.shape
166-
assert num_experts == b_a.shape[0] and K == w2.shape[1]
167-
out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device)
168-
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device)
169-
for expert in range(num_experts):
170-
num = tokens_per_expert[expert]
171-
if num > 0:
172-
torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1))
173-
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
174-
175-
return torch_combine(out, topk_weight, topk_ids)
176-
177-
178-
# TODO: same as torch_moe but with fused_topk factored out.
179-
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
180-
M, K = a.shape
181-
topk = topk_ids.shape[1]
182-
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
183-
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
184-
num_experts = w1.shape[0]
185-
for i in range(num_experts):
186-
mask = (topk_ids == i).view(-1)
187-
if mask.sum():
188-
out[mask] = SiluAndMul()(
189-
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
190-
191-
return (out.view(M, -1, w2.shape[1]) *
192-
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
193-
194-
195-
@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
196-
@pytest.mark.parametrize("n", [128, 1024, 2048])
197-
@pytest.mark.parametrize("k", [128, 511, 1024])
198-
@pytest.mark.parametrize("e", NUM_EXPERTS)
199-
@pytest.mark.parametrize("topk", TOP_KS)
200-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
201-
def test_fused_moe_batched_experts(
202-
m: int,
203-
n: int,
204-
k: int,
205-
e: int,
206-
topk: int,
207-
dtype: torch.dtype,
208-
):
209-
current_platform.seed_everything(7)
210-
211-
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
212-
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
213-
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
214-
215-
score = torch.randn((m, e), device="cuda", dtype=dtype)
216-
217-
vllm_config = VllmConfig()
218-
with set_current_vllm_config(vllm_config):
219-
topk_weight, topk_ids = fused_topk(a, score, topk, False)
220-
221-
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
222-
223-
if True:
224-
triton_output = torch_batched_moe(a,
225-
w1,
226-
w2,
227-
topk_weight,
228-
topk_ids)
229-
else:
230-
b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
231-
triton_output = fused_batched_experts(
232-
b_a,
233-
w1,
234-
w2,
235-
topk_weight,
236-
topk_ids,
237-
global_num_experts=e
238-
)
239-
240-
if False:
241-
torch.set_printoptions(profile="full")
242-
print("BASELINE")
243-
print(torch_output)
244-
print("OUTPUT")
245-
print(triton_output)
246-
247-
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
248-
249-
250113
@pytest.mark.parametrize("m", [1, 32, 222])
251114
@pytest.mark.parametrize("n", [128, 1024, 2048])
252115
@pytest.mark.parametrize("k", [128, 1024])
@@ -587,7 +450,8 @@ def test_fused_marlin_moe(
587450
topk_weights, topk_ids, token_expert_indices = fused_topk(
588451
a, score, topk, False)
589452

590-
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
453+
with set_current_vllm_config(vllm_config):
454+
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
591455

592456
marlin_output = torch.ops.vllm.fused_marlin_moe(
593457
a,

tests/kernels/moe/test_triton_moe_ptpc_fp8.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88

99
from vllm import _custom_ops as ops
10+
from vllm.config import VllmConfig, set_current_vllm_config
1011
from vllm.model_executor.layers.activation import SiluAndMul
1112
from vllm.model_executor.layers.fused_moe import fused_moe
1213
from vllm.platforms import current_platform
@@ -15,6 +16,10 @@
1516
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
1617
allow_module_level=True)
1718

19+
vllm_config = VllmConfig()
20+
vllm_config.scheduler_config.max_num_seqs = 128
21+
vllm_config.scheduler_config.max_model_len = 8192
22+
1823

1924
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
2025
"""Matrix multiplication function that supports per-token input
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
137142
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
138143
score = torch.randn((M, E), dtype=dtype)
139144

140-
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
141-
out = fused_moe(
142-
a,
143-
w1,
144-
w2,
145-
score,
146-
topk,
147-
renormalize=False,
148-
use_fp8_w8a8=True, # using fp8
149-
per_channel_quant=True,
150-
w1_scale=w1_s,
151-
w2_scale=w2_s,
152-
block_shape=None, # Not using block quantization
153-
)
145+
with set_current_vllm_config(vllm_config):
146+
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
147+
out = fused_moe(
148+
a,
149+
w1,
150+
w2,
151+
score,
152+
topk,
153+
renormalize=False,
154+
use_fp8_w8a8=True, # using fp8
155+
per_channel_quant=True,
156+
w1_scale=w1_s,
157+
w2_scale=w2_s,
158+
block_shape=None, # Not using block quantization
159+
)
154160

155161
# Check results
156162
rel_diff = (torch.mean(

tests/kernels/quantization/test_block_fp8.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
3131
allow_module_level=True)
3232

33+
vllm_config = VllmConfig()
34+
vllm_config.scheduler_config.max_num_seqs = 128
35+
vllm_config.scheduler_config.max_model_len = 8192
36+
3337
# Test configurations
3438
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
3539
NUM_TOKENS = [7, 83, 2048]
@@ -210,10 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
210214
score = torch.randn((M, E), dtype=dtype)
211215

212216
# Set the context to avoid lots of warning spam.
213-
vllm_config = VllmConfig()
214-
vllm_config.scheduler_config.max_num_seqs = 128
215-
vllm_config.scheduler_config.max_model_len = 8192
216-
217217
with set_current_vllm_config(vllm_config):
218218
out = fused_moe(
219219
a,
@@ -261,6 +261,7 @@ def per_block_cast_to_fp8(
261261
@pytest.mark.parametrize(
262262
"M,N,K,block_size,out_dtype,seed",
263263
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
264+
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
264265
@torch.inference_mode()
265266
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
266267
# only aligned sizes
@@ -426,26 +427,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
426427
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
427428
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
428429

429-
if True:
430-
dgm = modular_deep_gemm_fused_moe_fp8()
431-
432-
def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
433-
topk_ids):
434-
return dgm(a,
435-
w1,
436-
w2,
437-
topk_weights,
438-
topk_ids,
439-
w1_scale=w1_s,
440-
w2_scale=w2_s)
441-
else:
442-
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
443-
444430
# Set the context to avoid lots of warning spam.
445-
vllm_config = VllmConfig()
446-
vllm_config.scheduler_config.max_num_seqs = 128
447-
vllm_config.scheduler_config.max_model_len = 8192
448-
449431
with set_current_vllm_config(vllm_config):
450432
if M >= 128:
451433
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
@@ -457,8 +439,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
457439
topk_weights, topk_ids, token_expert_indices = fused_topk(
458440
a, score.float(), topk, False)
459441

460-
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
461-
topk_ids)
442+
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights,
443+
topk_ids)
462444

463445
#print(f"{out.sum()=}")
464446
#print(f"{ref_out.sum()=}")

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def dispatch(
2424
topk_ids: torch.Tensor,
2525
num_experts: int,
2626
expert_map: Optional[torch.Tensor],
27-
apply_router_weight_on_input: bool = False,
27+
apply_router_weight_on_input: bool,
2828
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2929
assert topk_ids.dim() == 2
3030
assert topk_ids.shape[0] == a1.shape[0]
@@ -99,8 +99,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
9999

100100
def __init__(
101101
self,
102-
rank: int = 0,
103-
world_size: int = 1,
104102
max_num_tokens: Optional[int] = None,
105103
use_fp8_w8a8: bool = False,
106104
use_int8_w8a8: bool = False,
@@ -116,8 +114,6 @@ def __init__(
116114
assert block_shape is None
117115
assert block_m is None
118116
self.max_num_tokens = max_num_tokens
119-
self.rank = rank
120-
self.world_size = world_size
121117
assert not use_fp8_w8a8, "NYI"
122118
assert not use_int8_w8a8, "NYI"
123119
assert not use_int8_w8a16, "NYI"
@@ -171,21 +167,14 @@ def apply(
171167
(num_experts, max_num_tokens, w2.shape[1]))
172168
num_local_experts = expert_num_tokens.numel()
173169

174-
# TODO: don't need world_size or rank if expert_base always == 0
175-
#assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}"
176-
#expert_base = rank_chunk(w1.shape[0], self.rank,
177-
# self.world_size) * self.rank
178-
expert_base = 0
179-
180170
for expert in range(num_local_experts):
181171
num = expert_num_tokens[expert]
182172
assert num <= max_num_tokens, f"{num}, {max_num_tokens}"
183173
if num > 0:
184174
tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
185175
self.activation(
186176
activation, tmp, hidden_states[expert, :num, :]
187-
@ w1[expert_base + expert].transpose(0, 1))
188-
out[expert, :num, :] = tmp @ w2[expert_base +
189-
expert].transpose(0, 1)
177+
@ w1[expert].transpose(0, 1))
178+
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
190179

191180
return out

0 commit comments

Comments
 (0)