|
| 1 | +""" |
| 2 | +Copyright (c) 2024 by FlashInfer team. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +import torch |
| 18 | +from torch.nn import functional as F |
| 19 | +from triton.testing import do_bench |
| 20 | + |
| 21 | +import flashinfer |
| 22 | +import flashinfer.fused_moe as fused_moe |
| 23 | +from flashinfer import fp4_quantize |
| 24 | + |
| 25 | +BATCH_SIZES = [ |
| 26 | + 1, |
| 27 | + 2, |
| 28 | + 4, |
| 29 | + 8, |
| 30 | + 16, |
| 31 | + 24, |
| 32 | + 32, |
| 33 | + 48, |
| 34 | + 64, |
| 35 | + 96, |
| 36 | + 128, |
| 37 | + 256, |
| 38 | + 512, |
| 39 | + 1024, |
| 40 | + 1536, |
| 41 | + 2048, |
| 42 | + 3072, |
| 43 | + 4096, |
| 44 | +] |
| 45 | + |
| 46 | +configs = [] |
| 47 | +hidden_size = 7168 |
| 48 | +num_experts = [32, 256] |
| 49 | +top_k = [8] |
| 50 | +intermediate_size = [256, 2048] |
| 51 | +FLOAT4_E2M1_MAX = 6.0 |
| 52 | +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max |
| 53 | +FP8_DTYPE = torch.float8_e4m3fn |
| 54 | + |
| 55 | +test_configs = [ |
| 56 | + { |
| 57 | + "hidden_size": 7168, |
| 58 | + "num_experts": 256, |
| 59 | + "top_k": 8, |
| 60 | + "intermediate_size": 256, |
| 61 | + }, |
| 62 | + { |
| 63 | + "hidden_size": 7168, |
| 64 | + "num_experts": 32, |
| 65 | + "top_k": 8, |
| 66 | + "intermediate_size": 2048, |
| 67 | + }, |
| 68 | +] |
| 69 | + |
| 70 | + |
| 71 | +def compute_routing( |
| 72 | + router_logits: torch.Tensor, top_k: int |
| 73 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 74 | + """ |
| 75 | + Compute routing weights and selected experts from router logits. |
| 76 | +
|
| 77 | + Args: |
| 78 | + router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts] |
| 79 | + top_k (int): Number of experts to route to per token |
| 80 | +
|
| 81 | + Returns: |
| 82 | + tuple[torch.Tensor, torch.Tensor]: A tuple containing: |
| 83 | + - routing_weights: Expert weights of shape [batch_size, top_k] |
| 84 | + - selected_experts: Expert indices of shape [batch_size, top_k] |
| 85 | + """ |
| 86 | + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) |
| 87 | + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) |
| 88 | + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) |
| 89 | + routing_weights = routing_weights.float() |
| 90 | + return routing_weights, selected_experts |
| 91 | + |
| 92 | + |
| 93 | +def bench_cutlass_fused_moe( |
| 94 | + batch_size, |
| 95 | + hidden_size, |
| 96 | + num_experts, |
| 97 | + top_k, |
| 98 | + intermediate_size, |
| 99 | +): |
| 100 | + torch.manual_seed(42) |
| 101 | + quant_blocksize = 16 |
| 102 | + round_up = lambda x, y: (x + y - 1) // y * y |
| 103 | + e = num_experts |
| 104 | + m = batch_size |
| 105 | + n = intermediate_size |
| 106 | + k = hidden_size |
| 107 | + otype = torch.bfloat16 |
| 108 | + wtype = torch.float8_e4m3fn |
| 109 | + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 |
| 110 | + w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous() |
| 111 | + |
| 112 | + sf_w1_2n = round_up(2 * n, 128) |
| 113 | + sf_w1_k = round_up(k // quant_blocksize, 4) |
| 114 | + w1_blockscale = torch.empty( |
| 115 | + (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn |
| 116 | + ) |
| 117 | + w1_blockscale_cutlass = torch.empty( |
| 118 | + (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn |
| 119 | + ) |
| 120 | + |
| 121 | + w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 |
| 122 | + sf_w2_k = round_up(k, 128) |
| 123 | + sf_w2_n = round_up(n // quant_blocksize, 4) |
| 124 | + w2_blockscale = torch.empty( |
| 125 | + (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn |
| 126 | + ) |
| 127 | + w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) |
| 128 | + w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) |
| 129 | + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) |
| 130 | + w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) |
| 131 | + w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) |
| 132 | + |
| 133 | + for expert in range(e): |
| 134 | + w1_amax = torch.abs(w1).max().to(torch.float32) |
| 135 | + w2_amax = torch.abs(w2).max().to(torch.float32) |
| 136 | + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax |
| 137 | + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax |
| 138 | + |
| 139 | + w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert]) |
| 140 | + |
| 141 | + w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize( |
| 142 | + w1_cutlass[expert], w1_gs[expert] |
| 143 | + ) |
| 144 | + |
| 145 | + w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert]) |
| 146 | + |
| 147 | + x = torch.randn(m, k, dtype=otype).cuda() |
| 148 | + a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to( |
| 149 | + torch.float32 |
| 150 | + ).cuda() |
| 151 | + a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) |
| 152 | + a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) |
| 153 | + router_logits = torch.randn(m, e, dtype=otype).cuda() |
| 154 | + routing_weights, selected_experts = compute_routing(router_logits, top_k) |
| 155 | + |
| 156 | + flash_output = torch.zeros_like(x) |
| 157 | + |
| 158 | + quant_scales = [ |
| 159 | + a1_gs, |
| 160 | + w1_blockscale.view(torch.int32), |
| 161 | + 1.0 / (a1_gs * w1_gs), |
| 162 | + a2_gs, |
| 163 | + w2_blockscale.view(torch.int32), |
| 164 | + 1.0 / (a2_gs * w2_gs), |
| 165 | + ] |
| 166 | + hidden_states = x |
| 167 | + hidden_states, input_sf = fp4_quantize(x, a1_gs) |
| 168 | + repeats = 3 |
| 169 | + from flashinfer.autotuner import AutoTuner, autotune |
| 170 | + |
| 171 | + AutoTuner.get().clear_cache() |
| 172 | + with torch.inference_mode(), autotune(): |
| 173 | + for _ in range(2): |
| 174 | + _ = fused_moe.cutlass_fused_moe( |
| 175 | + hidden_states, |
| 176 | + selected_experts.to(torch.int), |
| 177 | + routing_weights, |
| 178 | + w1_q.contiguous().view(torch.long), |
| 179 | + w2_q.contiguous().view(torch.long), |
| 180 | + otype, |
| 181 | + quant_scales=quant_scales, |
| 182 | + input_sf=input_sf, |
| 183 | + output=flash_output, |
| 184 | + ) |
| 185 | + ms = do_bench( |
| 186 | + lambda: fused_moe.cutlass_fused_moe( |
| 187 | + hidden_states, |
| 188 | + selected_experts.to(torch.int), |
| 189 | + routing_weights, |
| 190 | + w1_q.contiguous().view(torch.long), |
| 191 | + w2_q.contiguous().view(torch.long), |
| 192 | + otype, |
| 193 | + quant_scales=quant_scales, |
| 194 | + input_sf=input_sf, |
| 195 | + output=flash_output, |
| 196 | + ) |
| 197 | + ) |
| 198 | + print( |
| 199 | + f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" |
| 200 | + ) |
| 201 | + print(f"execution time: {ms}ms") |
| 202 | + |
| 203 | + |
| 204 | +if __name__ == "__main__": |
| 205 | + for config in test_configs: |
| 206 | + hidden_size = config["hidden_size"] |
| 207 | + num_experts = config["num_experts"] |
| 208 | + top_k = config["top_k"] |
| 209 | + intermediate_size = config["intermediate_size"] |
| 210 | + for batch_size in BATCH_SIZES: |
| 211 | + bench_cutlass_fused_moe( |
| 212 | + batch_size, |
| 213 | + hidden_size, |
| 214 | + num_experts, |
| 215 | + top_k, |
| 216 | + intermediate_size, |
| 217 | + ) |
0 commit comments