# SPDX-License-Identifier: Apache-2.0

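"""Tests for the batched MoE Triton kernels in
vllm.model_executor.layers.fused_moe.fused_batched_moe:

  * invoke_moe_batched_triton_kernel - a per-expert batched matmul that
    operates on the first ``num_expert_tokens[e]`` rows of each expert's slab.
  * invoke_batched_silu_and_mul - the batched SiLU-and-mul activation.

Each kernel is checked against a straightforward PyTorch/custom-op reference
implementation defined in this file.
"""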

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, N, K] - the [E, K, N] weights stored column-major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        # Scale the random inputs down so fp16/bf16 accumulation stays tame.
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 50.0
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype) / 50.0
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
    """Per-expert matmul reference: only the first num_expert_tokens[e] rows
    of each expert's batch are valid and written to C."""
    num_expert_tokens_cpu = num_expert_tokens.clone().to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e].item()
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [512])
@pytest.mark.parametrize("K", [256])
@pytest.mark.parametrize("N", [512])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):
    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]

    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)


@dataclass
class BatchedSiluMulConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    D: int


@dataclass
class BatchedSiluMulTensors:
    input: torch.Tensor  # [E, max_tokens, 2 * D]
    output: torch.Tensor  # [E, max_tokens, D]
    expert_num_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedSiluMulConfig):
        input = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
            device="cuda",
            dtype=config.dtype) / 50.0
        output = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.D),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedSiluMulTensors(input, output, num_expert_tokens)


def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
                         num_expert_tokens: torch.Tensor) -> None:
    """Apply silu_and_mul per expert, only over the valid token rows."""
    num_expert_tokens_cpu = num_expert_tokens.clone().to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e].item()
        out_part = output[e, :num_tokens, :]
        in_part = input[e, :num_tokens, :]
        torch.ops._C.silu_and_mul(out_part, in_part)


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [128])
@pytest.mark.parametrize("D", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int,
                          D: int, dtype: torch.dtype):
    config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D)
    tensors = BatchedSiluMulTensors.make_tensors(config)

    test_out = tensors.output
    ref_out = torch.zeros_like(test_out)

    ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)

    invoke_batched_silu_and_mul(test_out, tensors.input,
                                tensors.expert_num_tokens)

    torch.testing.assert_close(test_out, ref_out)
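
# To run these tests directly on a CUDA machine (the path below is
# illustrative; adjust it to wherever this file lives in the vLLM tree):
#   pytest tests/kernels/test_batched_moe.py -v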