
Commit 1374b98

KMSorSMS and mrhaoxx authored
[feat](moe_kernel): add amd blis support (int8) (#1600)
* [feat]: init amd adaption * [feat]: add blis support * [fix]: fix setup and moe kernel warpper * [fix](setup.py): support rebuild with cache and import kt_kernel works fine * [feat]: add moe_kernel converter for amd and implement the load method(haven't tested yet) * [feat](moe_kernel/moe.hpp): delete unused memory when using save * [fix](moe_kernel): update PLAIN for pack * [fix](moe_kernel): rm printf debug * [fix](moe_kernel): skip gpu experts * [fix](moe_kernel/moe.hpp): update include memory path * [feat](moe_kernel/moe.hpp): support expert deferral * [feat]: finish amd --------- Co-authored-by: mrhaoxx <[email protected]>
1 parent fef6dd9 commit 1374b98
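
One of the squashed fixes above ("[fix](setup.py): support rebuild with cache and import kt_kernel works fine") implies a quick smoke test. A minimal sketch of that check (the install step is an assumption; the module name comes from the commit message):

```bash
# Hypothetical verification: build/install, then confirm the module imports.
pip install -e .   # exact install command is an assumption
python -c "import kt_kernel; print('kt_kernel imported OK')"
```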

File tree: 14 files changed (+658, -241 lines)


kt-kernel/CMakeLists.txt (1 addition, 1 deletion)

@@ -495,7 +495,7 @@ if(NOT DEFINED CLANG_FORMAT_BIN)
   )
 endif()
 if(NOT CLANG_FORMAT_BIN)
-  message(WARNING "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
+  message(WARNING "ONLY for developer: clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
 else()
   execute_process(
     COMMAND ${CLANG_FORMAT_BIN} --version
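
The reworded warning is aimed at developers only. For completeness, a hedged sketch of the configure step the message itself suggests (the binary path is hypothetical):

```bash
# Point the build at an explicit clang-format >= 18, per the warning text.
cmake -B build -DCLANG_FORMAT_BIN=/usr/bin/clang-format-18
```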

kt-kernel/CMakePresets.json (14 additions, 0 deletions)

@@ -39,6 +39,20 @@
         "KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
         "KTRANSFORMERS_USE_CUDA": "ON"
       }
+    },
+    {
+      "name": "amd",
+      "displayName": "amd_platform",
+      "description": "for amd platform",
+      "cacheVariables": {
+        "KTRANSFORMERS_CPU_USE_AMX": "OFF",
+        "LLAMA_AVX512": "OFF",
+        "LLAMA_AVX2": "ON",
+        "KTRANSFORMERS_CPU_USE_AMX_AVX512": "OFF",
+        "KTRANSFORMERS_USE_CUDA": "ON",
+        "KTRANSFORMERS_CPU_MOE_AMD": "ON",
+        "KTRANSFORMERS_CPU_MOE_KERNEL": "ON"
+      }
     }
 
 ]
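
A plausible way to exercise the new preset, not shown in this commit (assumes the preset resolves a build directory named build):

```bash
# Configure with the AMD cache variables above, then build.
cmake --preset amd
cmake --build build
```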

kt-kernel/README.md (22 additions, 31 deletions)

@@ -2,41 +2,32 @@
 
 High-performance kernel operations for KTransformers, featuring CPU-optimized MoE inference with AMX, AVX, KML and blis (amd library) support.
 
-- [KT-Kernel](#kt-kernel)
-  - [Note](#note)
-  - [Features](#features)
-  - [Installation](#installation)
-    - [Prerequisites](#prerequisites)
-    - [Quick Installation (Recommended)](#quick-installation-recommended)
-    - [Manual Configuration (Advanced)](#manual-configuration-advanced)
-    - [Verification](#verification)
-  - [Integration with SGLang](#integration-with-sglang)
-    - [Installation Steps](#installation-steps)
-      - [1. Install SGLang](#1-install-sglang)
-      - [2. Prepare Weights](#2-prepare-weights)
-      - [3. Launch SGLang Server](#3-launch-sglang-server)
-    - [Complete Example: Qwen3-30B-A3B](#complete-example-qwen3-30b-a3b)
-      - [Option A: AMX Backend (AMXINT8)](#option-a-amx-backend-amxint8)
-      - [Option B: LLAMAFILE Backend (GGUF)](#option-b-llamafile-backend-gguf)
-    - [KT-Kernel Parameters](#kt-kernel-parameters)
-  - [Direct Python API Usage](#direct-python-api-usage)
-  - [Advanced Options](#advanced-options)
-    - [Build Configuration](#build-configuration)
-  - [Manual Installation](#manual-installation)
-    - [1. Install System Dependencies](#1-install-system-dependencies)
-    - [2. Set Build Configuration](#2-set-build-configuration)
-    - [3. Build and Install](#3-build-and-install)
-  - [Error Troubleshooting](#error-troubleshooting)
-    - [CUDA Not Found](#cuda-not-found)
-    - [hwloc Not Found](#hwloc-not-found)
-  - [Weight Quantization](#weight-quantization)
-  - [Before Commit!](#before-commit)
+- [Note](#note)
+- [Features](#features)
+- [Installation](#installation)
+  - [Prerequisites](#prerequisites)
+  - [Quick Installation (Recommended)](#quick-installation-recommended)
+  - [Manual Configuration (Advanced)](#manual-configuration-advanced)
+  - [Verification](#verification)
+- [Integration with SGLang](#integration-with-sglang)
+  - [Installation Steps](#installation-steps)
+  - [Complete Example: Qwen3-30B-A3B](#complete-example-qwen3-30b-a3b)
+  - [KT-Kernel Parameters](#kt-kernel-parameters)
+- [Direct Python API Usage](#direct-python-api-usage)
+- [Advanced Options](#advanced-options)
+  - [Build Configuration](#build-configuration)
+- [Manual Installation](#manual-installation)
+- [Error Troubleshooting](#error-troubleshooting)
+  - [CUDA Not Found](#cuda-not-found)
+  - [hwloc Not Found](#hwloc-not-found)
+- [Weight Quantization](#weight-quantization)
+- [Before Commit!](#before-commit)
 ## Note
 
 **Current Support Status:**
 - ✅ **Intel CPUs with AMX**: Fully supported (using weights converted to INT4/INT8 format)
 - ✅ **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights)
-- ⚠️ **AMD CPUs with BLIS**: In progress, not yet fully integrated
+- ✅ **AMD CPUs with BLIS**: Supported (for int8 prefill & decode)
 
 ## Features
 
@@ -145,7 +136,7 @@ python scripts/convert_cpu_weights.py \
     --input-path /path/to/model \
     --input-type bf16 \
     --output /path/to/cpu-weights \
-    --quant-method int8  # or int4
+    --quant-method int8  # or int4 or moe_int8 (for amd now)
 ```
 
 - `--input-path`: Path to GPU-side original weights
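
Combining the documented converter invocation with the new flag, a minimal sketch of the AMD conversion path (paths are hypothetical):

```bash
# Convert bf16 weights to the new moe_int8 format for the AMD/BLIS backend.
python scripts/convert_cpu_weights.py \
  --input-path /models/Qwen3-30B-A3B \
  --input-type bf16 \
  --output /models/Qwen3-30B-A3B-cpu-moe-int8 \
  --quant-method moe_int8
```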

kt-kernel/README_zh.md (22 additions, 31 deletions)

@@ -2,42 +2,33 @@
 
 高性能 KTransformers 内核库,提供面向 CPU 的高效 MoE 推理内核,支持 AMX 和 AVX 等后端。
 
-- [KT-Kernel](#kt-kernel)
-  - [说明](#说明)
-  - [特性](#特性)
-  - [安装](#安装)
-    - [先决条件](#先决条件)
-    - [快速安装(推荐)](#快速安装推荐)
-    - [手动配置(进阶)](#手动配置进阶)
-    - [验证安装](#验证安装)
-  - [与 SGLang 集成](#与-sglang-集成)
-    - [安装步骤](#安装步骤)
-      - [1. 安装 SGLang](#1-安装-sglang)
-      - [2. 准备权重](#2-准备权重)
-      - [3. 启动 SGLang Server](#3-启动-sglang-server)
-    - [完整示例:Qwen3-30B-A3B](#完整示例qwen3-30b-a3b)
-      - [方案 A:AMX 后端(AMXINT8)](#方案-aamx-后端amxint8)
-      - [方案 B:LLAMAFILE 后端(GGUF)](#方案-bllamafile-后端gguf)
-    - [KT-Kernel 参数](#kt-kernel-参数)
-  - [直接使用 Python API](#直接使用-python-api)
-  - [高级选项](#高级选项)
-    - [构建配置](#构建配置)
-  - [手动安装](#手动安装)
-    - [1. 安装系统依赖](#1-安装系统依赖)
-    - [2. 配置构建参数](#2-配置构建参数)
-    - [3. 构建并安装](#3-构建并安装)
-  - [错误排查](#错误排查)
-    - [找不到 CUDA](#找不到-cuda)
-    - [找不到 hwloc](#找不到-hwloc)
-  - [权重量化](#权重量化)
-  - [提交前必读](#提交前必读)
+- [说明](#说明)
+- [特性](#特性)
+- [安装](#安装)
+  - [先决条件](#先决条件)
+  - [快速安装(推荐)](#快速安装推荐)
+  - [手动配置(进阶)](#手动配置进阶)
+  - [验证安装](#验证安装)
+- [与 SGLang 集成](#与-sglang-集成)
+  - [安装步骤](#安装步骤)
+  - [完整示例:Qwen3-30B-A3B](#完整示例qwen3-30b-a3b)
+  - [KT-Kernel 参数](#kt-kernel-参数)
+- [直接使用 Python API](#直接使用-python-api)
+- [高级选项](#高级选项)
+  - [构建配置](#构建配置)
+- [手动安装](#手动安装)
+- [错误排查](#错误排查)
+  - [找不到 CUDA](#找不到-cuda)
+  - [找不到 hwloc](#找不到-hwloc)
+- [权重量化](#权重量化)
+- [提交前必读](#提交前必读)
 
 ## 说明
 
 **当前支持状态:**
 - ✅ **带 AMX 的 Intel CPU**:已支持(基于转换为 INT4/INT8 格式的权重)
 - ✅ **通用 CPU(llamafile 后端)**:已支持(基于 GGUF 格式的权重)
-- ⚠️ **带 BLIS 的 AMD CPU**:进行中,尚未完全集成
+- ✅ **带 BLIS 的 AMD CPU**:已支持(int8 的 prefill 和 decode)
 
 ## 特性
 
@@ -149,7 +140,7 @@ python scripts/convert_cpu_weights.py \
     --input-path /path/to/model \
     --input-type bf16 \
     --output /path/to/cpu-weights \
-    --quant-method int8  # 或 int4
+    --quant-method int8  # 或 int4 或 moe_int8(用于 amd 的)
 ```
 
 - `--input-path`:GPU 侧原始权重路径

kt-kernel/operators/amx/test/mmq-test.cpp (1 addition, 3 deletions)

@@ -2376,9 +2376,7 @@ bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {
   static thread_local bool is_first_time = true;
   if (is_first_time) {
 #pragma omp single
-    {
-      ggml_amx_init();
-    }
+    { ggml_amx_init(); }
 
     // load tile config
     ggml_tile_config_init();

kt-kernel/operators/amx/test/mmq.cpp (1 addition, 3 deletions)

@@ -2372,9 +2372,7 @@ bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {
   static thread_local bool is_first_time = true;
   if (is_first_time) {
 #pragma omp single
-    {
-      ggml_amx_init();
-    }
+    { ggml_amx_init(); }
 
     // load tile config
     ggml_tile_config_init();

kt-kernel/operators/llamafile/mla.hpp (21 additions, 20 deletions)

@@ -14,15 +14,15 @@
 // #include <utility>
 // #include <vector>
 
-// #define DIRECT_OR_POOL_BY(what, threshold, var, fn) \
-//   do { \
-//     if ((what) < (threshold)) { \
-//       for (int i = 0; i < (var); i++) { \
-//         (fn)(i); \
-//       } \
-//     } else { \
-//       pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
-//     } \
+// #define DIRECT_OR_POOL_BY(what, threshold, var, fn)              \
+//   do {                                                           \
+//     if ((what) < (threshold)) {                                  \
+//       for (int i = 0; i < (var); i++) {                          \
+//         (fn)(i);                                                 \
+//       }                                                          \
+//     } else {                                                     \
+//       pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
+//     }                                                            \
 // } while (0)
 
 // #define VEC_DOT_TYPE(type) (ggml_internal_get_type_traits((ggml_type)(type)).vec_dot_type)

@@ -31,19 +31,20 @@
 // #define QUANT_OFFSET(ptr, type, n, n_elements) \
 //   (offset_pointer((ptr), (size_t)(n) * QUANT_BLCK_SIZE((n_elements), (type))))
 
-// #define LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(m, n, k, a, a_type, b, b_col, c, c_col) \
-//   do { \
-//     llamafile_sgemm((m), (n), QUANT_BLCK_COUNT((k), (a_type)), (a), QUANT_BLCK_COUNT((k), (a_type)), \
-//                     QUANT_OFFSET((b), VEC_DOT_TYPE((a_type)), (b_col), (k)), \
-//                     QUANT_BLCK_COUNT((k), VEC_DOT_TYPE((a_type))), offset_pointer((c), (c_col) * (m) * sizeof(float)), \
-//                     (k), 0, 1, GGML_TASK_TYPE_COMPUTE, (a_type), VEC_DOT_TYPE((a_type)), GGML_TYPE_F32, \
-//                     GGML_PREC_DEFAULT); \
+// #define LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(m, n, k, a, a_type, b, b_col, c, c_col)                    \
+//   do {                                                                                               \
+//     llamafile_sgemm((m), (n), QUANT_BLCK_COUNT((k), (a_type)), (a), QUANT_BLCK_COUNT((k), (a_type)), \
+//                     QUANT_OFFSET((b), VEC_DOT_TYPE((a_type)), (b_col), (k)),                         \
+//                     QUANT_BLCK_COUNT((k), VEC_DOT_TYPE((a_type))), offset_pointer((c), (c_col) * (m) * \
+//                     sizeof(float)),                                                                  \
+//                     (k), 0, 1, GGML_TASK_TYPE_COMPUTE, (a_type), VEC_DOT_TYPE((a_type)), GGML_TYPE_F32, \
+//                     GGML_PREC_DEFAULT);                                                              \
 // } while (0)
 
-// #define LLAMAFILE_SGEMM_MATMUL_F32(m, n, k, a, lda, b, ldb, c, ldc) \
-//   do { \
-//     llamafile_sgemm((m), (n), (k), (a), (lda), (b), (ldb), (c), (ldc), 0, 1, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F32, \
-//                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT); \
+// #define LLAMAFILE_SGEMM_MATMUL_F32(m, n, k, a, lda, b, ldb, c, ldc)                                      \
+//   do {                                                                                                   \
+//     llamafile_sgemm((m), (n), (k), (a), (lda), (b), (ldb), (c), (ldc), 0, 1, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F32, \
+//                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT);                                    \
 // } while (0)
 
 // // bool decide_absorb(size_t a,int a_type,size_t b,int b_type,size_t c,int c_type,size_t d,int d_type){

kt-kernel/operators/moe_kernel/la/kernel.hpp (2 additions, 2 deletions)

@@ -340,7 +340,7 @@ struct GemmKernelInt8 {
   static inline const int PACK_SIZE_M = 8;
   static inline const int PACK_SIZE_K = 32;
 
-  static std::string name() { return "INT8"; }
+  static std::string name() { return "MOE_INT8"; }
   static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
   // type_: d for decode, p for prefill
   static int recommended_nth_down(int n, char type_ = 'd') {

@@ -833,7 +833,7 @@ struct GemmKernelInt4 {
   static inline const int PACK_SIZE_K = 32;
   static inline const int PACK_SIZE_M = 8;
 
-  static std::string name() { return "INT4"; }
+  static std::string name() { return "MOE_INT4"; }
   static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
 
   static int recommended_nth_down(int n, char type_ = 'd') {
