import argparse
from typing import Optional, Literal
import torch
import numpy as np
from flashinfer import (
    fp4_quantize,
    mxfp8_quantize,
    next_positive_power_of_2,
)
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl


def get_tile_tokens_dim(num_tokens, num_experts, top_k):
    # Factor accounting for expert load imbalance. It is the ratio
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
    # - 1.0 means a perfectly balanced expert distribution.
    # - > 1.0 means some experts receive more tokens than under the perfect
    #   distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Number of tokens per expert assuming a perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # Pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Clamp to 8-64 tokens per CTA tile, the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim
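# Worked example (illustrative, using this script's default arguments): with
# 512 tokens, 128 experts, and top_k = 4, the perfectly balanced load is
# 512 * 4 // 128 = 16 tokens per expert; scaling by the 1.3 imbalance factor
# gives 20, and the next power of two is 32, which already lies in the
# supported 8-64 range, so get_tile_tokens_dim(512, 128, 4) returns 32.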


def bench_trtllm_gen_fused_moe_autotuner(
    tune_max_num_tokens: Optional[int],
    quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    warmups: int,
    iterations: int,
):
    device = torch.device("cuda:0")
    enable_pdl = device_support_pdl(device)
    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
        torch.bfloat16
    )
    hidden_states = torch.randn(num_tokens, hidden_size, device=device).to(
        torch.bfloat16
    )
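    # The quant_mode name reads "<weight format>x<activation format>"; the
    # branches below quantize the activations accordingly, and the matching
    # weight quantization follows further down.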
    if quant_mode == "NvFP4xNvFP4":
        # Global scale 448 * 6 = max FP8 E4M3 value times max FP4 E2M1 value.
        hidden_states, hidden_states_scale = fp4_quantize(
            hidden_states,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
            num_tokens, -1
        )
        hidden_states_global_scale = 1.0 / 448.0 / 6.0
    elif quant_mode == "MxFP4xMxFP8":
        hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False)
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
            num_tokens, -1
        )
        hidden_states_global_scale = 1.0
    else:  # MxFP4xBf16
        hidden_states_scale = None
        hidden_states_global_scale = 1.0

    w13 = torch.randn(
        num_experts, intermediate_size * 2, hidden_size, device=device
    ).to(torch.bfloat16)
    w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to(
        torch.bfloat16
    )
    if quant_mode == "NvFP4xNvFP4":
        w13, w13_scale = fp4_quantize(
            w13,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, intermediate_size * 2, -1
        )
        w2, w2_scale = fp4_quantize(
            w2,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, -1
        )
        w13_global_scale = 1.0 / 448.0 / 6.0
        w2_global_scale = 1.0 / 448.0 / 6.0
    else:
        w13, w13_scale = fp4_quantize(
            w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
        )
        w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, intermediate_size * 2, -1
        )
        w2, w2_scale = fp4_quantize(
            w2, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
        )
        w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, -1
        )
        w13_global_scale = 1.0
        w2_global_scale = 1.0
    bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
    # The second GEMM produces hidden_size features, so its bias is sized to match.
    bias2 = torch.randn(num_experts, hidden_size, device=device) * 10

    tile_tokens_dim = get_tile_tokens_dim(num_tokens, num_experts, top_k)
    output1_scale_scalar = torch.tensor(
        [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
    )
    output1_scale_gate_scalar = torch.tensor(
        [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
    )
    output2_scale_scalar = torch.tensor(
        [hidden_states_global_scale * w2_global_scale] * num_experts, device=device
    )
    fn = lambda: trtllm_fp4_block_scale_moe(
        routing_logits,
        None,  # routing_bias
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        None,  # gemm1_alpha
        None,  # gemm1_beta
        None,  # gemm1_clamp_limit
        w2,
        w2_scale,
        bias2,
        output1_scale_scalar,
        output1_scale_gate_scalar,
        output2_scale_scalar,
        num_experts,
        top_k,
        None,  # n_group
        None,  # topk_group
        intermediate_size,
        0,  # local_expert_offset
        num_experts,  # local_num_experts
        None,  # routed_scaling_factor
        tile_tokens_dim,
        1,  # routing_method_type
        True,  # do_finalize
        enable_pdl,
        None,  # output
        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
    )

    def bench(do_autotune):
        # Warmup; with do_autotune=True the autotuner searches configs here.
        with autotune(do_autotune):
            for _ in range(warmups):
                fn()
        ms_list = bench_gpu_time(
            fn,
            repeat_iters=iterations,
        )
        median_ms = np.median(ms_list)
        return median_ms

    ms = bench(do_autotune=False)
    ms_tuned = bench(do_autotune=True)
    print(
        f"num tokens: {num_tokens}, num experts: {num_experts}, "
        f"hidden size: {hidden_size}, intermediate size: {intermediate_size}, "
        f"top k: {top_k}"
    )
    print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant-mode",
        type=str,
        default="MxFP4xMxFP8",
        choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
        help="Quantization mode",
    )
    parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens")
    parser.add_argument(
        "--tune-max-num-tokens",
        type=int,
        default=None,
        help="Maximum number of tokens for tuning",
    )
    parser.add_argument(
        "--num-experts", type=int, default=128, help="Number of experts"
    )
    parser.add_argument("--hidden-size", type=int, default=3072, help="Hidden size")
    parser.add_argument(
        "--intermediate-size", type=int, default=3072, help="Intermediate size"
    )
    parser.add_argument("--top-k", type=int, default=4, help="Top-k experts per token")
    parser.add_argument(
        "--warmups", type=int, default=100, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of benchmark iterations"
    )
    args = parser.parse_args()
    bench_trtllm_gen_fused_moe_autotuner(
        args.tune_max_num_tokens,
        args.quant_mode,
        args.num_tokens,
        args.num_experts,
        args.hidden_size,
        args.intermediate_size,
        args.top_k,
        args.warmups,
        args.iterations,
    )
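# Example invocation (illustrative; assumes the script is saved as
# bench_trtllm_gen_fused_moe_autotuner.py and a GPU supported by the
# trtllm-gen FP4 block-scale MoE kernels is available):
#   python bench_trtllm_gen_fused_moe_autotuner.py \
#       --quant-mode NvFP4xNvFP4 --num-tokens 1024 --tune-max-num-tokens 2048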