
Commit fcf8882

SkqLiao, ouqingliang, and claude authored
[Feature] Add avx-based kimi-k2 support (#1656)
* Support Kimi-K2-Thinking original weights; fix AMX kernel bug
* Update K2 AVX kernel
* feat: add CPUInfer write buffer task
* [feat]: add Kimi K2 CPU write buffer support
  - Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
  - Fix down (w2) weight column-wise slicing for different TP configurations
  - Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp
  - Add comprehensive test cases for weight extraction validation
  - Ensure compatibility with the Kimi model's MoE architecture
* [fix]: correct write_weight_scale_to_buffer expert offset calculation
  Fixed a bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated: the code now uses full gpu_tp sizes instead of per_expert_gpu sizes, ensuring the correct memory layout for multi-expert scenarios. Also added benchmark scripts for the K2 MoE and write buffer operations, and cleaned up debug output in the test files.
  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <[email protected]>
* [feat]: add write buffer wrapper
* [fix]: fix comment
---------
Co-authored-by: ouqingliang <[email protected]>
Co-authored-by: Claude <[email protected]>
1 parent c2b8c60 commit fcf8882
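The write-buffer work described in the commit message hinges on column-wise slicing of the down (w2) expert weight when the CPU and GPU tensor-parallel degrees differ. The actual implementation is the C++ write_weights_to_buffer in k2-moe.hpp; the Python sketch below only illustrates the slicing relationship for the three TP scenarios (cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp). The function and variable names here are hypothetical and are not part of this commit.

import torch

hidden_size = 7168
intermediate_size = 2048


def tp_cols(rank: int, tp: int, total_cols: int):
    # Start/stop of the intermediate-dim columns owned by `rank` under TP degree `tp`.
    per_rank = total_cols // tp
    return rank * per_rank, (rank + 1) * per_rank


def gather_down_slice_for_gpu_rank(cpu_shards, cpu_tp, gpu_rank, gpu_tp):
    # Assemble the down (w2) columns needed by GPU rank `gpu_rank`.
    # cpu_shards[c] is the column slice held by CPU rank c, of shape
    # (hidden_size, intermediate_size // cpu_tp). The same loop covers
    # cpu_tp == gpu_tp, cpu_tp > gpu_tp (several CPU shards feed one GPU shard),
    # and cpu_tp < gpu_tp (one CPU shard feeds several GPU shards).
    want_start, want_stop = tp_cols(gpu_rank, gpu_tp, intermediate_size)
    pieces = []
    for c, shard in enumerate(cpu_shards):
        have_start, have_stop = tp_cols(c, cpu_tp, intermediate_size)
        lo, hi = max(want_start, have_start), min(want_stop, have_stop)
        if lo < hi:  # this CPU shard overlaps the GPU rank's column range
            pieces.append(shard[:, lo - have_start : hi - have_start])
    return torch.cat(pieces, dim=1)


# Self-check: slicing a full weight directly and via CPU shards must agree.
full = torch.randn(hidden_size, intermediate_size)
for cpu_tp, gpu_tp in [(2, 2), (4, 2), (2, 4)]:
    per_cpu = intermediate_size // cpu_tp
    shards = [full[:, c * per_cpu : (c + 1) * per_cpu] for c in range(cpu_tp)]
    per_gpu = intermediate_size // gpu_tp
    for g in range(gpu_tp):
        got = gather_down_slice_for_gpu_rank(shards, cpu_tp, g, gpu_tp)
        assert torch.equal(got, full[:, g * per_gpu : (g + 1) * per_gpu])

Computing the overlap of the two column ranges handles all three TP scenarios with a single code path, which is the intent the commit message describes.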

File tree: 12 files changed, +2649 −34 lines changed

Lines changed: 363 additions & 0 deletions
@@ -0,0 +1,363 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales.
"""
import json
import math
import os
import platform
import subprocess
import sys
import time

from tqdm import tqdm

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

import kt_kernel_ext
import torch

# Benchmark parameters (single MoE, no layer loop)
expert_num = 384
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
qlen = 1
warm_up_iter = 1000
test_iter = 5000
k_group_size = 32

physical_to_logical_map = (
    torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
)

worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map = [0, 1]
worker_config.subpool_thread_count = [40, 40]
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)


def get_git_commit():
    result = {}
    try:
        commit = (
            subprocess.check_output(["git", "rev-parse", "HEAD"])
            .decode("utf-8")
            .strip()
        )
        commit_msg = (
            subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
            .decode("utf-8")
            .strip()
        )
        result["commit"] = commit
        result["commit_message"] = commit_msg

        dirty_output = (
            subprocess.check_output(["git", "status", "--porcelain"])
            .decode("utf-8")
            .strip()
        )
        if dirty_output:
            result["dirty"] = True
            result["dirty_files"] = dirty_output.splitlines()
        else:
            result["dirty"] = False
    except Exception as e:
        result["commit"] = None
        result["commit_message"] = None
        result["dirty"] = None
        result["error"] = str(e)
    return result


def get_system_info():
    info = {}
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node

    cpu_model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_model = line.split(":", 1)[1].strip()
                        break
        except Exception as e:
            cpu_model = f"Error: {e}"
    info["cpu_model"] = cpu_model

    mem_total_gb = None
    if os.path.exists("/proc/meminfo"):
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if "MemTotal" in line:
                        mem_kb = float(line.split(":", 1)[1].split()[0])
                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)
                        break
        except Exception as e:
            mem_total_gb = f"Error: {e}"
    info["memory_size_GB"] = mem_total_gb

    info["cpu_core_count"] = os.cpu_count()

    sockets = set()
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "physical id" in line:
                        sockets.add(line.split(":", 1)[1].strip())
        except Exception:
            sockets = set()
    info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1

    return info


script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")


def record_results(result, filename=json_path):
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")


def pack_to_int32(
    value: torch.Tensor, num_bits: int, packed_dim: int = 1
) -> torch.Tensor:
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be torch.int8 before packing")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")

    offset = 1 << (num_bits - 1)
    value = (value + offset).to(torch.uint8)
    device = value.device

    pack_factor = 32 // num_bits

    if packed_dim == 0:
        value = value.transpose(0, 1)

    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    pad_len = padded_cols - cols

    if pad_len > 0:
        value = torch.nn.functional.pad(value, (0, pad_len))

    num_groups = padded_cols // pack_factor
    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

    if packed_dim == 0:
        packed = packed.transpose(0, 1)

    return packed


def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    e, rows, cols = q.shape
    flat = q.view(e * rows, cols)
    packed = pack_to_int32(flat, num_bits)
    return packed.view(e, rows, -1).contiguous()


def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """
    K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales.
    """
    weights_f32 = weights.to(torch.float32)
    e, rows, cols = weights_f32.shape
    if cols % group_size != 0 or cols % 2 != 0:
        raise ValueError(
            f"cols ({cols}) must be divisible by group_size ({group_size}) and 2"
        )

    reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    q = q.view(e, rows, cols)
    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
    scales = scales.to(torch.bfloat16).contiguous().view(
        e, rows, cols // group_size
    ).contiguous()
    return packed, scales


def build_quantized_layer_weights():
    gate_proj = torch.randn(
        (expert_num, intermediate_size, hidden_size),
        dtype=torch.float32,
        device="cpu",
    ).contiguous()
    up_proj = torch.randn(
        (expert_num, intermediate_size, hidden_size),
        dtype=torch.float32,
        device="cpu",
    ).contiguous()
    down_proj = torch.randn(
        (expert_num, hidden_size, intermediate_size),
        dtype=torch.float32,
        device="cpu",
    ).contiguous()

    gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)
    up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)
    down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)

    return {
        "gate_qweight": gate_q,
        "up_qweight": up_q,
        "down_qweight": down_q,
        "gate_scales": gate_scales,
        "up_scales": up_scales,
        "down_scales": down_scales,
    }


def bench_k2_moe():
    with torch.inference_mode():
        bytes_per_elem = 0.5 + 2.0 / k_group_size

        quant_data = build_quantized_layer_weights()
        config = kt_kernel_ext.moe.MOEConfig(
            expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
        )
        config.max_len = max_len
        config.quant_config.bits = 4
        config.quant_config.group_size = k_group_size
        config.quant_config.zero_point = False

        config.gate_proj = quant_data["gate_qweight"].data_ptr()
        config.up_proj = quant_data["up_qweight"].data_ptr()
        config.down_proj = quant_data["down_qweight"].data_ptr()

        config.gate_scale = quant_data["gate_scales"].data_ptr()
        config.up_scale = quant_data["up_scales"].data_ptr()
        config.down_scale = quant_data["down_scales"].data_ptr()
        config.pool = CPUInfer.backend_

        moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()

        gen_iter = 3000
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand(
            (gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu"
        ).contiguous()
        input_tensor = torch.randn(
            (qlen, hidden_size), dtype=torch.bfloat16, device="cpu"
        ).contiguous()
        output_tensor = torch.empty_like(input_tensor)
        bsz_tensor = torch.tensor([qlen], device="cpu")

        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor.data_ptr(),
                    output_tensor.data_ptr(),
                    False,
                )
            )
        CPUInfer.sync()

        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor.data_ptr(),
                    output_tensor.data_ptr(),
                    False,
                )
            )
        CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        time_per_iter_us = total_time / test_iter * 1e6
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )
        flops = (
            hidden_size
            * intermediate_size
            * qlen
            * 3
            * num_experts_per_tok
            * 2
            * test_iter
            / total_time
            / 1e12
        )

        print("Quant mode: int4_k2")
        print("Time(s): ", total_time)
        print("Iteration: ", test_iter)
        print("Time(us) per iteration: ", time_per_iter_us)
        print("Bandwidth: ", bandwidth, "GB/s")
        print("Flops: ", flops, "TFLOPS")
        print("")

        result = {
            "quant_mode": "int4_k2",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": flops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "max_len": max_len,
                "num_experts_per_tok": num_experts_per_tok,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "k_group_size": k_group_size,
                "bytes_per_elem": bytes_per_elem,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)


if __name__ == "__main__":
    bench_k2_moe()
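A reading note on the bandwidth estimate above: the factor num_experts_per_tok * (1 / 8 * 256 * (1 - (31 / 32) ** qlen)) simplifies to 256 * (1 - (31 / 32) ** qlen), the expected number of distinct experts touched when qlen tokens each pick 8 of 256 experts uniformly at random. The script's expert_num is 384, so the 256 and 31/32 constants appear to be inherited from a 256-expert configuration; with the default qlen = 1 both variants evaluate to exactly 8 distinct experts, so the reported bandwidth is unaffected. A parametric sketch of the same estimate, not part of the script:

def expected_active_experts(expert_num: int, num_experts_per_tok: int, qlen: int) -> float:
    # Expected number of distinct experts activated when `qlen` tokens each
    # select `num_experts_per_tok` of `expert_num` experts uniformly at random.
    p_untouched = (1 - num_experts_per_tok / expert_num) ** qlen
    return expert_num * (1 - p_untouched)


print(expected_active_experts(256, 8, 1))  # 8.0 -- matches the hard-coded constants
print(expected_active_experts(384, 8, 1))  # 8.0 -- matches expert_num = 384 at qlen = 1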
