|
| 1 | +""" |
| 2 | +Performance benchmark for FP8 Per-Channel MoE kernel (GLM-4.7-FP8 style). |
| 3 | +
|
| 4 | +This benchmark measures the performance of the FP8 Per-Channel MoE operator with: |
| 5 | +- FP8 (E4M3) weights with per-channel scaling (one scale per output row) |
| 6 | +- BF16 activations |
| 7 | +- AVX-512 DPBF16 compute path |
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +import sys |
| 12 | +import time |
| 13 | +import json |
| 14 | +import subprocess |
| 15 | +import platform |
| 16 | + |
| 17 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) |
| 18 | + |
| 19 | +import torch |
| 20 | +from kt_kernel import kt_kernel_ext |
| 21 | +from tqdm import tqdm |
| 22 | + |
# Test parameters (GLM-4.x style MoE dimensions)
expert_num = 256            # total number of experts per MoE layer
hidden_size = 7168          # model hidden dimension
intermediate_size = 2048    # per-expert FFN intermediate dimension
num_experts_per_tok = 8     # top-k experts routed per token
max_len = 25600             # maximum sequence length configured on the kernel

layer_num = 2               # number of MoE layers built and cycled through
qlen = 1                    # tokens per forward call (decode-style vector-matrix path)
warm_up_iter = 1000         # untimed warm-up iterations
test_iter = 3000            # timed iterations
CPUINFER_PARAM = 80         # passed to CPUInfer -- presumably the worker-thread count; TODO confirm

# Shared CPU inference engine / thread pool, reused by every MoE layer below
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)

# Result file path: results are appended as JSON Lines next to this script
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")
| 42 | + |
| 43 | + |
def get_git_commit():
    """Collect git metadata for the current working tree.

    Returns a dict with ``commit``, ``commit_message``, ``dirty`` (and
    ``dirty_files`` when the tree is dirty).  On any failure -- e.g. git is
    missing or the cwd is not a repository -- ``commit`` is None and
    ``error`` carries the exception text.
    """
    info = {}

    def _git(*args):
        # Small helper: run a git subcommand and return its stripped stdout.
        return subprocess.check_output(["git", *args]).decode("utf-8").strip()

    try:
        info["commit"] = _git("rev-parse", "HEAD")
        info["commit_message"] = _git("log", "-1", "--pretty=%B")
        status = _git("status", "--porcelain")
        info["dirty"] = bool(status)
        if status:
            info["dirty_files"] = status.splitlines()
    except Exception as exc:
        info["commit"] = None
        info["error"] = str(exc)
    return info
| 60 | + |
| 61 | + |
def get_system_info():
    """Return basic host facts: OS name, hostname, CPU model string, core count."""
    uname = platform.uname()
    info = {"system_name": uname.system, "node_name": uname.node}

    model = None
    try:
        # /proc/cpuinfo is Linux-only; on other platforms the open() fails
        # and the model stays None.
        with open("/proc/cpuinfo", "r") as fh:
            for row in fh:
                if "model name" in row:
                    model = row.split(":", 1)[1].strip()
                    break
    except Exception:
        model = None
    info["cpu_model"] = model
    info["cpu_core_count"] = os.cpu_count()
    return info
| 82 | + |
| 83 | + |
def record_results(result, filename=json_path):
    """Serialize *result* to JSON and append it as one line (JSON Lines)."""
    payload = json.dumps(result)
    with open(filename, "a") as sink:
        sink.write(payload + "\n")
| 88 | + |
| 89 | + |
def generate_fp8_perchannel_weights_direct(shape: tuple):
    """
    Directly generate random FP8 weights and per-channel scales.

    Args:
        shape: (expert_num, n, k) - weight tensor shape

    Returns:
        fp8_weights: uint8 tensor with random FP8 E4M3 bit patterns
        scales: fp32 tensor with per-channel scales, shape [expert_num, n]
    """
    e, n, k = shape

    # Random FP8 E4M3 bit patterns stored as raw uint8 (1 sign + 4 exp + 3 mantissa).
    # Generated directly on CPU: this is a CPU-kernel benchmark, so requiring a
    # CUDA device just to build test data (as the original cuda->cpu round trip
    # did) would crash on GPU-less hosts for no benefit.
    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cpu").contiguous()

    # Random per-channel scales (one per output row), restricted to powers of
    # two in [2^-8, 2^8] so dequantization is exact in float32.
    exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cpu")
    scales = (2.0 ** exponents.float()).contiguous()

    return fp8_weights, scales
| 113 | + |
| 114 | + |
def bench_fp8_perchannel_moe():
    """Benchmark FP8 Per-Channel MoE performance.

    Builds `layer_num` AMXFP8PerChannel_MOE layers from random FP8 weights,
    runs a warm-up phase, then times `test_iter` forward calls and reports
    per-iteration latency, effective weight-read bandwidth, and TFLOPS.
    The result record (plus git and system metadata) is appended to the
    JSONL log via record_results().

    Returns:
        (tflops, bandwidth_GBs) tuple of the measured metrics.
    """
    with torch.inference_mode():
        print("=" * 70)
        print("FP8 Per-Channel MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate FP8 weights with per-channel scales.
        # Fixed seed so weight/scale content is identical across runs.
        print("\nGenerating FP8 weights with per-channel scales...")
        torch.manual_seed(42)
        gate_fp8, gate_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
        up_fp8, up_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
        down_fp8, down_scales = generate_fp8_perchannel_weights_direct((expert_num, hidden_size, intermediate_size))

        # Identity mapping: physical expert i holds logical expert i.
        physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers
        print("Building FP8 Per-Channel MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = 0  # Not used for per-channel
            config.quant_config.zero_point = False
            config.quant_config.per_channel = True  # Enable per-channel mode

            # Raw data pointers handed to the C++ extension: the Python-side
            # tensors above must stay alive for the lifetime of the layer.
            config.gate_proj = gate_fp8.data_ptr()
            config.up_proj = up_fp8.data_ptr()
            config.down_proj = down_fp8.data_ptr()
            config.gate_scale = gate_scales.data_ptr()
            config.up_scale = up_scales.data_ptr()
            config.down_scale = down_scales.data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)
            # Weight loading is submitted as a task and synced before the
            # next layer is built.
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data: a pool of gen_iter pre-generated routing
        # decisions, cycled through during the benchmark loop.
        print("Generating input data...")
        gen_iter = 1000
        # argsort of uniform noise yields num_experts_per_tok DISTINCT random
        # expert ids per token (top-k routing without replacement).
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup -- same call pattern as the timed loop below; layers are
        # cycled via i % layer_num (presumably so consecutive iterations touch
        # different weight buffers -- confirm against kernel cache behavior).
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        # Benchmark: each iteration is submitted and synced individually, so
        # the measured time includes per-call submit/sync overhead.
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6

        # FLOPS calculation:
        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12

        # Bandwidth calculation (FP8 = 1 byte per element)
        bytes_per_elem = 1.0
        # Weight memory: gate + up + down per expert.  The parenthesized term
        # rescales num_experts_per_tok to the EXPECTED number of distinct
        # experts touched per call, E * (1 - (1 - k/E)^qlen); for qlen == 1
        # it equals 1, leaving exactly k experts' worth of weight reads.
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: FP8 (E4M3) with per-channel scaling")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results: one self-describing JSONL record including the full
        # parameter set so runs remain comparable across code changes.
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "fp8_e4m3_perchannel",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "quant_type": "per_channel",
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)

        return tflops, bandwidth
| 274 | + |
| 275 | + |
# Script entry point: run the benchmark once and append results to the JSONL log.
if __name__ == "__main__":
    bench_fp8_perchannel_moe()
0 commit comments