 from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
                                  torch_moe_single)
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)

 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.model_executor.layers.activation import SiluAndMul

 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -106,6 +107,141 @@ def test_fused_moe(
                                rtol=0)


+def batch_by_experts(
+    a: torch.Tensor,
+    topk_ids: torch.Tensor,
+    num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert topk_ids.dim() == 2
+    assert topk_ids.shape[0] == a.shape[0]
+
+    # Count how many tokens are routed to each expert.
+    tokens_per_expert = torch.bincount(topk_ids.view(-1).long(),
+                                       minlength=num_experts)
+
+    # Allocate one zero-padded batch per expert, sized for the busiest one.
+    max_num_tokens = tokens_per_expert.max()
+    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
+                      dtype=a.dtype, device=a.device)
+
+    # Scatter each token into the next free slot of every expert it is
+    # routed to; a per-expert cursor keeps the slots from colliding.
+    expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
+    for i in range(topk_ids.shape[0]):
+        for j in range(topk_ids.shape[1]):
+            expert_id = topk_ids[i, j]
+            idx = expert_counts[expert_id]
+            b_a[expert_id, idx, :] = a[i, :]
+            expert_counts[expert_id] = idx + 1
+
+    return b_a, tokens_per_expert
+
+
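+# Illustrative sketch (an assumption for exposition, not part of the test):
+# for a = [t0, t1], topk_ids = [[0, 1], [1, 2]] and num_experts = 3,
+# batch_by_experts yields tokens_per_expert = [1, 2, 1] and a zero-padded
+# b_a of shape (3, 2, K) with b_a[0] = [t0, 0], b_a[1] = [t0, t1] and
+# b_a[2] = [t1, 0].
+
+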
+def unbatch_output(b_out, topk_ids, K):
+    num_tokens, topk = topk_ids.shape
+    num_experts = b_out.shape[0]
+    out = torch.zeros((num_tokens, topk, K),
+                      dtype=b_out.dtype, device=b_out.device)
+    # Walk the tokens in the same order batch_by_experts wrote them, so
+    # each expert's cursor points at the matching slot.
+    expert_counts = torch.zeros(num_experts, dtype=torch.int,
+                                device=b_out.device)
+    for token in range(num_tokens):
+        expert_ids = topk_ids[token]
+        for i in range(expert_ids.numel()):
+            expert_id = expert_ids[i]
+            idx = expert_counts[expert_id]
+            out[token, i, :] = b_out[expert_id, idx, :]
+            expert_counts[expert_id] = idx + 1
+
+    return out
+
+
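+# Round-trip invariant of the two helpers above: unbatching the batched
+# activations recovers a topk-replicated copy of the input, i.e.
+# unbatch_output(batch_by_experts(a, ids, E)[0], ids, K)[t, j] == a[t]
+# for every token t and slot j, because both walk tokens in the same
+# order with per-expert cursors.
+
+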
+def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
+    assert a.dim() == 3
+    num_tokens = topk_ids.shape[0]
+    max_num_tokens = a.shape[1]
+    num_experts = w1.shape[0]
+    out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]),
+                      dtype=a.dtype, device=a.device)
+    for expert in range(num_experts):
+        num = tokens_per_expert[expert]
+        if num > 0:
+            # Only the first `num` rows hold real tokens; the zero padding
+            # beyond them is skipped and stays zero in the output.
+            out[expert, :num, :] = SiluAndMul()(
+                a[expert, :num, :] @ w1[expert].transpose(0, 1)
+            ) @ w2[expert].transpose(0, 1)
+
+    out = unbatch_output(out, topk_ids, w2.shape[1])
+
+    # Weight each expert's output by its router score and reduce over topk.
+    return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
+
+
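+# Note: vLLM's SiluAndMul halves the last dimension, computing
+# silu(x[..., :d]) * x[..., d:], which is why w1 produces 2 * n features
+# while w2 consumes only n.
+
+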
+def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
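+# torch_moe2 is the dense reference: it replicates every token topk times
+# and masks per expert, so it needs no batching metadata and serves as the
+# ground truth for the batched path below.
+
+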
+@pytest.mark.parametrize("m", [1, 33, 64, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+
+        b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
+
+        use_batched_ref = True
+        if use_batched_ref:
+            batched_output = torch_batched_moe(b_a, w1, w2, tokens_per_expert,
+                                               topk_weight, topk_ids)
+        else:
+            # Placeholder path: fused_experts does not consume the batched
+            # activations (b_a) yet, so it runs on the unbatched input.
+            batched_output = fused_experts(a, w1, w2, topk_weight, topk_ids,
+                                           global_num_experts=e)
+
+        torch.testing.assert_close(batched_output, torch_output,
+                                   atol=2e-2, rtol=0)
+
+
 @pytest.mark.parametrize("m", [1, 32, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 1024])