Skip to content

Commit 6277da4

Browse files
authored
support GLM 4.7 (#1791)
support GLM 4.7
1 parent 667030d commit 6277da4

File tree

14 files changed

+2336
-144
lines changed

14 files changed

+2336
-144
lines changed

kt-kernel/bench/bench_fp8_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
1818

1919
import torch
20-
import kt_kernel_ext
20+
from kt_kernel import kt_kernel_ext
2121
from tqdm import tqdm
2222

2323
# Test parameters
@@ -29,9 +29,9 @@
2929
max_len = 25600
3030

3131
layer_num = 2
32-
qlen = 1024
33-
warm_up_iter = 10
34-
test_iter = 30
32+
qlen = 1
33+
warm_up_iter = 1000
34+
test_iter = 3000
3535
CPUINFER_PARAM = 80
3636

3737
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""
Performance benchmark for FP8 Per-Channel MoE kernel (GLM-4.7-FP8 style).

This benchmark measures the performance of the FP8 Per-Channel MoE operator with:
- FP8 (E4M3) weights with per-channel scaling (one scale per output row)
- BF16 activations
- AVX-512 DPBF16 compute path
"""

import os
import sys
import time
import json
import subprocess
import platform

# Allow importing the locally built extension without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

import torch
from kt_kernel import kt_kernel_ext
from tqdm import tqdm

# Test parameters (GLM-4.7-style MoE shapes)
expert_num = 256  # total experts per MoE layer
hidden_size = 7168  # model hidden dimension
intermediate_size = 2048  # per-expert FFN intermediate dimension
num_experts_per_tok = 8  # top-k experts routed per token
max_len = 25600

layer_num = 2  # MoE layers cycled through during timing to avoid cache reuse
qlen = 1  # decode-style benchmark: one token per forward call
warm_up_iter = 1000
test_iter = 3000
CPUINFER_PARAM = 80  # constructor argument for CPUInfer — presumably the worker count; TODO confirm against kt_kernel_ext

CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)

# Result file path (results are appended as JSON lines next to this script)
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")
42+
43+
44+
def get_git_commit():
    """Collect git metadata (commit hash, message, dirty state) for the repo.

    Best-effort: on any failure the returned dict has ``commit`` set to
    ``None`` and an ``error`` message instead of raising.
    """
    def _git(*args):
        # Run a git subcommand and return its stripped stdout.
        return subprocess.check_output(["git", *args]).decode("utf-8").strip()

    info = {}
    try:
        info["commit"] = _git("rev-parse", "HEAD")
        info["commit_message"] = _git("log", "-1", "--pretty=%B")
        status = _git("status", "--porcelain")
        info["dirty"] = bool(status)
        if status:
            info["dirty_files"] = status.splitlines()
    except Exception as exc:
        # Keep whatever was gathered, but flag the record as incomplete.
        info["commit"] = None
        info["error"] = str(exc)
    return info
60+
61+
62+
def get_system_info():
    """Gather host identification: OS name, hostname, CPU model, core count."""
    uname = platform.uname()
    details = {"system_name": uname.system, "node_name": uname.node}

    model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as cpuinfo:
                # The first "model name" entry is representative of all cores.
                model = next(
                    (row.split(":", 1)[1].strip() for row in cpuinfo if "model name" in row),
                    None,
                )
        except Exception:
            # /proc parsing is best-effort; leave the model unknown on failure.
            pass
    details["cpu_model"] = model
    details["cpu_core_count"] = os.cpu_count()
    return details
82+
83+
84+
def record_results(result, filename=json_path):
    """Append one benchmark record to the shared JSONL results file."""
    entry = json.dumps(result) + "\n"
    with open(filename, "a") as sink:
        sink.write(entry)
88+
89+
90+
def generate_fp8_perchannel_weights_direct(shape: tuple):
    """
    Directly generate random FP8 weights and per-channel scales.

    Args:
        shape: (expert_num, n, k) - weight tensor shape

    Returns:
        fp8_weights: uint8 tensor with random FP8 E4M3 values
        scales: fp32 tensor with per-channel scales, shape [expert_num, n]
    """
    e, n, k = shape

    # Generate on GPU when available (much faster for the large benchmark
    # shapes), but fall back to CPU so this CPU-kernel benchmark also runs
    # on CUDA-less hosts instead of crashing in torch.randint(device="cuda").
    gen_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Random raw bytes reinterpreted later as FP8 E4M3 (1 sign + 4 exp +
    # 3 mantissa); the kernel only needs plausible bit patterns for timing,
    # not trained weights.
    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device=gen_device).to("cpu").contiguous()

    # Random per-channel scales (one per output row), restricted to
    # powers of two in a reasonable range (2^-8 to 2^8).
    exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device=gen_device).to("cpu").contiguous()
    scales = (2.0 ** exponents.float()).to(torch.float32).contiguous()

    return fp8_weights, scales
113+
114+
115+
def bench_fp8_perchannel_moe():
    """Benchmark FP8 Per-Channel MoE performance.

    Builds `layer_num` AMXFP8PerChannel_MOE layers with randomly generated
    FP8 weights, runs warm-up and timed forward iterations through CPUInfer,
    then prints and records throughput metrics (time/iter, GB/s, TFLOPS).

    Returns:
        (tflops, bandwidth): measured compute throughput in TFLOPS and
        estimated weight-read bandwidth in GB/s.
    """
    with torch.inference_mode():
        print("=" * 70)
        print("FP8 Per-Channel MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate FP8 weights with per-channel scales
        print("\nGenerating FP8 weights with per-channel scales...")
        torch.manual_seed(42)
        gate_fp8, gate_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
        up_fp8, up_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
        down_fp8, down_scales = generate_fp8_perchannel_weights_direct((expert_num, hidden_size, intermediate_size))

        # Identity mapping: physical expert i holds logical expert i.
        physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers
        print("Building FP8 Per-Channel MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = 0  # Not used for per-channel
            config.quant_config.zero_point = False
            config.quant_config.per_channel = True  # Enable per-channel mode

            # The extension reads weights/scales through raw pointers; the
            # backing tensors must stay alive until load_weights_task completes.
            config.gate_proj = gate_fp8.data_ptr()
            config.up_proj = up_fp8.data_ptr()
            config.down_proj = down_fp8.data_ptr()
            config.gate_scale = gate_scales.data_ptr()
            config.up_scale = up_scales.data_ptr()
            config.down_scale = down_scales.data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data
        print("Generating input data...")
        gen_iter = 1000  # number of pre-generated routing samples, cycled during the run
        # argsort of uniform noise yields num_experts_per_tok *distinct* expert
        # ids per token (a random top-k routing without replacement).
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        # Benchmark: time submit+sync per iteration (synchronous latency).
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6

        # FLOPS calculation:
        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12

        # Bandwidth calculation (FP8 = 1 byte per element)
        bytes_per_elem = 1.0
        # Weight memory: gate + up + down per expert.
        # The parenthesized factor rescales by the expected number of
        # *distinct* experts hit across qlen tokens (an expert shared by
        # several tokens is counted once) — NOTE(review): confirm this
        # matches the kernel's actual weight-reuse behavior.
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: FP8 (E4M3) with per-channel scaling")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results (appended to bench_results.jsonl with git/host metadata)
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "fp8_e4m3_perchannel",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "quant_type": "per_channel",
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)

        return tflops, bandwidth
274+
275+
276+
if __name__ == "__main__":
    # Run the benchmark and append its results to bench_results.jsonl.
    bench_fp8_perchannel_moe()

0 commit comments

Comments
 (0)