Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions kt-kernel/ext_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ static const bool _is_plain_ = false;
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
// AVX2 backends — always available on x86_64 (no AMX/AVX512 dependency)
#if defined(__x86_64__)
#include "operators/avx2/bf16-moe.hpp"
#include "operators/avx2/fp8-moe.hpp"
#include "operators/avx2/gptq_int4-moe.hpp"
#endif

#include <pybind11/stl.h> // std::vector/std::pair/std::string conversions

#include <cstdint>
Expand Down Expand Up @@ -578,6 +585,13 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
#endif
#endif
// AVX2 backends — available on all x86_64 (no AMX/AVX512 requirement)
#if defined(__x86_64__)
bind_moe_module<AVX2_BF16_MOE_TP<avx2::GemmKernelAVX2BF16>>(moe_module, "AVX2BF16_MOE");
bind_moe_module<AVX2_FP8_MOE_TP<avx2::GemmKernelAVX2FP8>>(moe_module, "AVX2FP8_MOE");
bind_moe_module<AVX2_GPTQ_INT4_MOE_TP<avx2::GemmKernelAVX2GPTQInt4>>(moe_module, "AVX2GPTQInt4_MOE");
#endif

#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
#if defined(__aarch64__) && defined(CPU_USE_KML)
Expand Down
228 changes: 228 additions & 0 deletions kt-kernel/operators/avx2/avx2_bf16_gemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/**
 * @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions
 * @Author      : Claude
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * Unlike AMX kernels that use packed tile layouts (BufferB with 16x16 transpose),
 * the AVX2 kernel uses row-major storage for all buffers.
 * BufferA/B/C are thin wrappers over raw memory with trivial from_mat/to_mat.
 *
 * GEMM: C[m,n] = sum_k A[m,k] * B[n,k]
 *   A: [M, K] row-major BF16 (input activations)
 *   B: [N, K] row-major BF16 (weights, each row is one output neuron)
 *   C: [M, N] row-major FP32 (output)
 **/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
#define CPUINFER_OPERATOR_AVX2_BF16_GEMM_H

#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <tuple>

#include "avx2_bf16_utils.hpp"

namespace avx2 {

// Split range [0, total) among nth threads, return [start, end) for thread ith
// Partition [0, total) into nth contiguous chunks and return the half-open
// [begin, end) slice owned by thread ith. The first (total % nth) chunks
// receive one extra element so the split is as even as possible.
static inline std::pair<int, int> split_range(int total, int ith, int nth) {
  const int base = total / nth;
  const int extra = total % nth;
  // Threads with index < extra own (base + 1) elements; the rest own base.
  const int begin = ith * base + std::min(ith, extra);
  const int len = base + (ith < extra ? 1 : 0);
  return {begin, begin + len};
}

struct GemmKernelAVX2BF16 {
  using dt = ggml_bf16_t;  // weight/activation element type
  using output_t = float;  // accumulator/output element type

  static constexpr int M_STEP = 1;     // No M-direction padding needed (vs AMX 16)
  static constexpr int N_STEP = 8;     // 8-wide FP32 AVX2 (vs AMX 32)
  static constexpr int K_STEP = 8;     // Process 8 K elements at a time
  static constexpr int N_BLOCK = 64;   // N blocking for cache
  static constexpr int K_BLOCK = 256;  // K blocking for cache
  // BF16 = 2 bytes. Kept as double — presumably to mirror the AMX kernel
  // trait consumed by shared sizing code; TODO confirm before changing type.
  static constexpr double ELEMENT_SIZE = 2.0;

  // No tile configuration needed for plain AVX2 (AMX kernels program tiles here).
  static void config() {}

  // Thread count for N-dimension parallelism.
  // Must return >= 1 to avoid division by zero in moe_base task dispatch.
  static int recommended_nth(int n) {
    return std::max(1, n / N_STEP);
  }

  // Split the N dimension into a per-thread [start, end) range.
  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    return split_range(n, ith, nth);
  }

  // ========================================================================
  // BufferA: Input activations [M, K] row-major BF16
  // from_mat() = memcpy (no packing needed for AVX2)
  // ========================================================================
  struct BufferA {
    ggml_bf16_t* data = nullptr;  // non-owning, row-major [max_m, k]
    size_t max_m = 0;             // row capacity
    size_t k = 0;                 // row length (K dimension)

    BufferA() = default;
    // Initializer order matches declaration order (data, max_m, k).
    BufferA(size_t m, size_t k_, void* ptr) : data((ggml_bf16_t*)ptr), max_m(m), k(k_) {}

    // Bytes needed to hold an [m, k] BF16 matrix.
    static size_t required_size(size_t m, size_t k) {
      return m * k * sizeof(ggml_bf16_t);
    }

    void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }

    // Copy the first m input rows into the buffer. A single-threaded caller
    // takes the one-shot memcpy; otherwise rows are split across threads.
    void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
      if (ith == 0 && nth == 1) {
        std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
      } else {
        // Multi-threaded: split by rows; explicit size_t offsets to avoid
        // any int-width surprises for large matrices.
        auto [m_start, m_end] = split_range(m, ith, nth);
        std::memcpy(data + (size_t)m_start * k, src + (size_t)m_start * k,
                    (size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
      }
    }
  };

  // ========================================================================
  // BufferB: Weight matrix [N, K] row-major BF16
  // from_mat() = memcpy (no transpose/packing needed)
  // ========================================================================
  struct BufferB {
    ggml_bf16_t* b = nullptr;  // non-owning, row-major [n, k]
    size_t n = 0;              // number of output neurons (rows)
    size_t k = 0;              // row length (K dimension)

    BufferB() = default;
    // Initializer order matches declaration order (b, n, k).
    BufferB(size_t n_, size_t k_, void* ptr) : b((ggml_bf16_t*)ptr), n(n_), k(k_) {}

    static size_t required_size(size_t n, size_t k) {
      return n * k * sizeof(ggml_bf16_t);
    }

    // Added for consistency with BufferA/BufferC.
    void set_data(void* ptr) { b = (ggml_bf16_t*)ptr; }

    // Copy weight data, split across threads along the N dimension.
    void from_mat(const ggml_bf16_t* src, int ith, int nth) {
      auto [n_start, n_end] = split_range((int)n, ith, nth);
      std::memcpy(b + (size_t)n_start * k, src + (size_t)n_start * k,
                  (size_t)(n_end - n_start) * k * sizeof(ggml_bf16_t));
    }
  };

  // ========================================================================
  // BufferC: Output matrix [M, N] row-major FP32
  // to_mat() converts FP32 -> BF16 and writes out
  // ========================================================================
  struct BufferC {
    float* data = nullptr;  // non-owning, row-major [max_m, n]
    size_t max_m = 0;       // row capacity
    size_t n = 0;           // row length (N dimension)

    BufferC() = default;
    // Initializer order matches declaration order (data, max_m, n).
    BufferC(size_t m, size_t n_, void* ptr) : data((float*)ptr), max_m(m), n(n_) {}

    static size_t required_size(size_t m, size_t n) {
      return m * n * sizeof(float);
    }

    void set_data(void* ptr) { data = (float*)ptr; }

    // Convert the FP32 result to BF16 and write it to dst. Threads split the
    // N dimension; every thread walks all m rows of its own column slice.
    void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
      auto [n_start, n_end] = split_range_n((int)n, ith, nth);
      for (int mi = 0; mi < m; mi++) {
        float* src_row = data + (size_t)mi * n;
        ggml_bf16_t* dst_row = dst + (size_t)mi * n;
        int j = n_start;
        // Vectorized body: 8 floats -> 8 bf16 per iteration.
        for (; j + 8 <= n_end; j += 8) {
          __m256 v = _mm256_loadu_ps(src_row + j);
          store_fp32_to_bf16(dst_row + j, v);
        }
        // Scalar tail for slice widths not a multiple of 8.
        for (; j < n_end; j++) {
          dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
        }
      }
    }
  };
};

// ============================================================================
// AVX2 BF16 GEMM functions
// C[m,n] = sum_k A[m,k] * B[n,k]
// ============================================================================

// General GEMM (works for both vec_mul m=1 and mat_mul m>1)
static inline void gemm_bf16(
int m, int n, int k,
GemmKernelAVX2BF16::BufferA& a,
GemmKernelAVX2BF16::BufferB& b,
GemmKernelAVX2BF16::BufferC& c,
int ith, int nth) {

auto [n_start, n_end] = split_range(n, ith, nth);

for (int ni = n_start; ni < n_end; ni++) {
const ggml_bf16_t* b_row = b.b + (size_t)ni * k;

for (int mi = 0; mi < m; mi++) {
const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;

// AVX2 BF16 dot product (matches ggml_vec_dot_bf16 AVX2 path)
__m256 c1 = _mm256_setzero_ps();
__m256 c2 = _mm256_setzero_ps();
__m256 c3 = _mm256_setzero_ps();
__m256 c4 = _mm256_setzero_ps();

int ki = 0;
for (; ki + 32 <= k; ki += 32) {
c1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki), load_bf16_to_fp32(b_row + ki), c1);
c2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 8), load_bf16_to_fp32(b_row + ki + 8), c2);
c3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 16), load_bf16_to_fp32(b_row + ki + 16), c3);
c4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 24), load_bf16_to_fp32(b_row + ki + 24), c4);
}

float sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(c1, c3), _mm256_add_ps(c2, c4)));

// Scalar tail
for (; ki < k; ki++) {
sum += GGML_BF16_TO_FP32(a_row[ki]) * GGML_BF16_TO_FP32(b_row[ki]);
}

c.data[mi * n + ni] = sum;
}
}
}

// vec_mul: dispatch to gemm_bf16
// vec_mul (single-row case): AVX2 has no dedicated m == 1 path, so this
// simply forwards to the general gemm_bf16 kernel.
static inline void vec_mul(
    int m, int n, int k,
    std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
    std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
    std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c,
    int ith, int nth) {
  auto& buf_a = *a;
  auto& buf_b = *b;
  auto& buf_c = *c;
  gemm_bf16(m, n, k, buf_a, buf_b, buf_c, ith, nth);
}

// mat_mul: dispatch to gemm_bf16
// mat_mul (multi-row case): forwards to the general gemm_bf16 kernel,
// identical to vec_mul since AVX2 uses one code path for all m.
static inline void mat_mul(
    int m, int n, int k,
    std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
    std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
    std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c,
    int ith, int nth) {
  auto& buf_a = *a;
  auto& buf_b = *b;
  auto& buf_c = *c;
  gemm_bf16(m, n, k, buf_a, buf_b, buf_c, ith, nth);
}

} // namespace avx2

#endif // CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
132 changes: 132 additions & 0 deletions kt-kernel/operators/avx2/avx2_bf16_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/**
 * @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation)
 * @Author      : Claude
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * AVX2 ports of the AVX512 utilities in amx/la/utils.hpp and amx/la/amx.hpp.
 * Uses 256-bit SIMD (8 floats) instead of 512-bit (16 floats).
 **/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
#define CPUINFER_OPERATOR_AVX2_BF16_UTILS_H

#include <immintrin.h>
#include <cmath>
#include "llama.cpp/ggml.h"

namespace avx2 {

// ============================================================================
// BF16 <-> FP32 conversion
// ============================================================================

// Load 8 BF16 values and convert to 8 FP32 values
// BF16 is the upper 16 bits of FP32, so shift left by 16
// Load 8 BF16 values and widen them to 8 FP32 values.
// BF16 is the upper 16 bits of an FP32 bit pattern, so zero-extending each
// 16-bit lane to 32 bits and shifting left by 16 reproduces the FP32 value.
static inline __m256 load_bf16_to_fp32(const ggml_bf16_t* src) {
  const __m128i raw = _mm_loadu_si128((const __m128i*)src);
  const __m256i widened = _mm256_cvtepu16_epi32(raw);
  const __m256i fp32_bits = _mm256_slli_epi32(widened, 16);
  return _mm256_castsi256_ps(fp32_bits);
}

// Convert 8 FP32 values to 8 BF16 values with round-to-nearest-even
// Matches ggml_compute_fp32_to_bf16 semantics (ggml-impl.h:87)
// and amx/la/utils.hpp:24 tie-bit correction
// Convert 8 FP32 values to 8 BF16 values with round-to-nearest-even.
// Uses the same rounding as ggml_compute_fp32_to_bf16 (ggml-impl.h) and the
// amx/la/utils.hpp tie-bit correction: add 0x7FFF plus the current tie bit,
// then truncate to the upper 16 bits.
// NOTE(review): unlike ggml's scalar path, NaN inputs are not special-cased
// here; a NaN whose payload carries into the exponent on rounding could
// become +/-inf. Confirm callers never feed NaNs, or add the ggml NaN check.
static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) {
  __m256i i32 = _mm256_castps_si256(src);
  // Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1) before truncating.
  __m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
  __m256i round = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), tie_bit);
  __m256i rounded = _mm256_add_epi32(i32, round);
  __m256i shifted = _mm256_srli_epi32(rounded, 16);
  // Pack 32-bit -> 16-bit. _mm_packus_epi32 works per 128-bit lane, so split
  // the 256-bit vector first; after >>16 every lane fits in 16 bits, so the
  // unsigned saturation never triggers.
  __m128i lo = _mm256_castsi256_si128(shifted);
  __m128i hi = _mm256_extracti128_si256(shifted, 1);
  __m128i packed = _mm_packus_epi32(lo, hi);
  _mm_storeu_si128((__m128i*)dst, packed);
}

// Load 16 BF16 -> 2x8 FP32 (corresponds to avx512_32xbf16_to_32xfp32)
// Widen 16 consecutive BF16 values into two 8-lane FP32 vectors
// (half-width analogue of avx512_32xbf16_to_32xfp32).
static inline void load_16xbf16_to_2x8xfp32(const ggml_bf16_t* src, __m256* out0, __m256* out1) {
  out0[0] = load_bf16_to_fp32(src + 0);
  out1[0] = load_bf16_to_fp32(src + 8);
}

// Store 2x8 FP32 -> 16 BF16 (corresponds to avx512_32xfp32_to_32xbf16)
// Narrow two 8-lane FP32 vectors into 16 consecutive BF16 values
// (half-width analogue of avx512_32xfp32_to_32xbf16).
static inline void store_2x8xfp32_to_16xbf16(__m256* in0, __m256* in1, ggml_bf16_t* dst) {
  store_fp32_to_bf16(dst + 0, in0[0]);
  store_fp32_to_bf16(dst + 8, in1[0]);
}

// ============================================================================
// Horizontal sum for __m256 (8 floats -> 1 float)
// ============================================================================

// Horizontal sum of a __m256 (8 floats -> 1 float).
// Folds 256 -> 128 -> 64 -> 32 bits; the add order is identical to the
// original implementation, so the float result is bit-identical.
static inline float hsum_avx2(__m256 v) {
  __m128 half = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
  half = _mm_add_ps(half, _mm_movehl_ps(half, half));      // lanes {0,1} += {2,3}
  half = _mm_add_ss(half, _mm_movehdup_ps(half));          // lane 0 += lane 1
  return _mm_cvtss_f32(half);
}

// ============================================================================
// Fast exp approximation (AVX2 port of amx::exp_avx512)
// ============================================================================

// Fast vectorized e^x approximation (AVX2 port of amx::exp_avx512).
// Decomposes e^x = 2^(x*log2(e)) = 2^i * 2^f with integer i and fractional f,
// evaluating 2^f with a degree-5 polynomial (same coefficients as the AMX
// kernel) and 2^i by direct construction of the FP32 bit pattern.
static inline __m256 exp_avx2(__m256 x) {
  const __m256 log2e = _mm256_set1_ps(1.44269504089f);

  __m256 y = _mm256_mul_ps(x, log2e);
  __m256i int_part = _mm256_cvtps_epi32(y);  // rounded integer exponent i
  __m256 frac_part = _mm256_sub_ps(y, _mm256_cvtepi32_ps(int_part));

  // Polynomial coefficients for 2^f on the fractional part.
  const __m256 poly_1 = _mm256_set1_ps(0.9999999995f);
  const __m256 poly_2 = _mm256_set1_ps(0.6931471805f);
  const __m256 poly_3 = _mm256_set1_ps(0.2402265069f);
  const __m256 poly_4 = _mm256_set1_ps(0.0555041087f);
  const __m256 poly_5 = _mm256_set1_ps(0.0096181291f);
  const __m256 poly_6 = _mm256_set1_ps(0.0013333558f);

  // Horner evaluation: ((((p6*f + p5)*f + p4)*f + p3)*f + p2)*f + p1
  __m256 frac_exp = _mm256_fmadd_ps(
      _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4),
                                      frac_part, poly_3),
                      frac_part, poly_2),
      frac_part, poly_1);

  // 2^int_part: AVX2 lacks _mm256_scalef_ps, so build the float bit pattern
  // manually: 2^n = reinterpret((n + 127) << 23) for normal floats.
  // Clamp int_part to [-126, 127] to avoid invalid bit patterns:
  //   int_part < -126 -> biased < 1   -> denorm/zero (scalef_ps would give 0)
  //   int_part >  127 -> biased > 254 -> inf        (scalef_ps would give inf)
  __m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)),
                                     _mm256_set1_epi32(-126));
  __m256i biased = _mm256_add_epi32(clamped, _mm256_set1_epi32(127));
  __m256i shifted = _mm256_slli_epi32(biased, 23);
  __m256 two_pow_i = _mm256_castsi256_ps(shifted);

  return _mm256_mul_ps(two_pow_i, frac_exp);
}

// ============================================================================
// SiLU activation: silu(gate) * up = gate * sigmoid(gate) * up
// AVX2 port of amx::act_fn
// ============================================================================

// SiLU-gated activation: silu(gate) * up = gate * sigmoid(gate) * up.
// AVX2 port of amx::act_fn; computes sigmoid via 1 / (1 + e^{-gate}).
static inline __m256 act_fn(__m256 gate_val, __m256 up_val) {
  const __m256 kOne = _mm256_set1_ps(1.0f);
  const __m256 kExpClamp = _mm256_set1_ps(88.0f);  // keep exp_avx2 input in range

  // Negate the gate, clamping so the exponential cannot overflow.
  __m256 neg_gate = _mm256_min_ps(_mm256_sub_ps(_mm256_setzero_ps(), gate_val), kExpClamp);
  __m256 denom = _mm256_add_ps(kOne, exp_avx2(neg_gate));
  __m256 silu = _mm256_div_ps(gate_val, denom);

  return _mm256_mul_ps(silu, up_val);
}

} // namespace avx2

#endif // CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
Loading
Loading