-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[feat](kt-kernel): support avx2 only inference for bf16 fp8 and gptq int4 #1892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,228 @@ | ||
| /** | ||
| * @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions | ||
| * @Author : Claude | ||
| * @Date : 2026-03-18 | ||
| * @Version : 1.0.0 | ||
| * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. | ||
| * | ||
| * Unlike AMX kernels that use packed tile layouts (BufferB with 16x16 transpose), | ||
| * the AVX2 kernel uses row-major storage for all buffers. | ||
| * BufferA/B/C are thin wrappers over raw memory with trivial from_mat/to_mat. | ||
| * | ||
| * GEMM: C[m,n] = sum_k A[m,k] * B[n,k] | ||
| * A: [M, K] row-major BF16 (input activations) | ||
| * B: [N, K] row-major BF16 (weights, each row is one output neuron) | ||
| * C: [M, N] row-major FP32 (output) | ||
| **/ | ||
| #ifndef CPUINFER_OPERATOR_AVX2_BF16_GEMM_H | ||
| #define CPUINFER_OPERATOR_AVX2_BF16_GEMM_H | ||
|
|
||
| #include <immintrin.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <cassert> | ||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <memory> | ||
| #include <tuple> | ||
|
|
||
| #include "avx2_bf16_utils.hpp" | ||
|
|
||
| namespace avx2 { | ||
|
|
||
// Split range [0, total) among nth threads, return [start, end) for thread ith.
// The first (total % nth) threads each receive one extra element, so the
// partition is as balanced as possible and covers [0, total) exactly.
static inline std::pair<int, int> split_range(int total, int ith, int nth) {
    const int chunk = total / nth;
    const int leftover = total % nth;
    const int begin = ith * chunk + std::min(ith, leftover);
    const int extra = (ith < leftover) ? 1 : 0;
    return {begin, begin + chunk + extra};
}
|
|
||
| struct GemmKernelAVX2BF16 { | ||
| using dt = ggml_bf16_t; | ||
| using output_t = float; | ||
| static constexpr int M_STEP = 1; // No M-direction padding needed (vs AMX 16) | ||
| static constexpr int N_STEP = 8; // 8-wide FP32 AVX2 (vs AMX 32) | ||
| static constexpr int K_STEP = 8; // Process 8 K elements at a time | ||
| static constexpr int N_BLOCK = 64; // N blocking for cache | ||
| static constexpr int K_BLOCK = 256; // K blocking for cache | ||
| static constexpr double ELEMENT_SIZE = 2.0; // BF16 = 2 bytes | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The static constexpr size_t BF16_BYTE_SIZE = 2;
static constexpr size_t ELEMENT_SIZE = BF16_BYTE_SIZE; |
||
|
|
||
| // No AMX tile configuration needed | ||
| static void config() {} | ||
|
|
||
| // Thread count for N-dimension parallelism | ||
| // Must return >= 1 to avoid division by zero in moe_base task dispatch | ||
| static int recommended_nth(int n) { | ||
| return std::max(1, n / N_STEP); | ||
| } | ||
|
|
||
| // Split N range for multi-threaded GEMM | ||
| static std::pair<int, int> split_range_n(int n, int ith, int nth) { | ||
| return split_range(n, ith, nth); | ||
| } | ||
|
|
||
| // ======================================================================== | ||
| // BufferA: Input activations [M, K] row-major BF16 | ||
| // from_mat() = memcpy (no packing needed for AVX2) | ||
| // ======================================================================== | ||
| struct BufferA { | ||
| ggml_bf16_t* data = nullptr; | ||
| size_t max_m = 0; | ||
| size_t k = 0; | ||
|
|
||
| BufferA() = default; | ||
| BufferA(size_t m, size_t k_, void* ptr) : max_m(m), k(k_), data((ggml_bf16_t*)ptr) {} | ||
|
|
||
| static size_t required_size(size_t m, size_t k) { | ||
| return m * k * sizeof(ggml_bf16_t); | ||
| } | ||
|
|
||
| void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; } | ||
|
|
||
| // Copy input rows into buffer (trivial memcpy) | ||
| void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) { | ||
| if (ith == 0 && nth == 1) { | ||
| std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t)); | ||
| } else { | ||
| // Multi-threaded: split by rows | ||
| auto [m_start, m_end] = split_range(m, ith, nth); | ||
| std::memcpy(data + m_start * k, src + m_start * k, | ||
| (size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t)); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| // ======================================================================== | ||
| // BufferB: Weight matrix [N, K] row-major BF16 | ||
| // from_mat() = memcpy (no transpose/packing needed) | ||
| // ======================================================================== | ||
| struct BufferB { | ||
| ggml_bf16_t* b = nullptr; | ||
| size_t n = 0; | ||
| size_t k = 0; | ||
|
|
||
| BufferB() = default; | ||
| BufferB(size_t n_, size_t k_, void* ptr) : n(n_), k(k_), b((ggml_bf16_t*)ptr) {} | ||
|
|
||
| static size_t required_size(size_t n, size_t k) { | ||
| return n * k * sizeof(ggml_bf16_t); | ||
| } | ||
|
|
||
| // Copy weight data (multi-threaded by N dimension) | ||
| void from_mat(const ggml_bf16_t* src, int ith, int nth) { | ||
| auto [n_start, n_end] = split_range((int)n, ith, nth); | ||
| std::memcpy(b + n_start * k, src + n_start * k, | ||
| (size_t)(n_end - n_start) * k * sizeof(ggml_bf16_t)); | ||
| } | ||
| }; | ||
|
|
||
| // ======================================================================== | ||
| // BufferC: Output matrix [M, N] row-major FP32 | ||
| // to_mat() converts FP32 -> BF16 and writes out | ||
| // ======================================================================== | ||
| struct BufferC { | ||
| float* data = nullptr; | ||
| size_t max_m = 0; | ||
| size_t n = 0; | ||
|
|
||
| BufferC() = default; | ||
| BufferC(size_t m, size_t n_, void* ptr) : max_m(m), n(n_), data((float*)ptr) {} | ||
|
|
||
| static size_t required_size(size_t m, size_t n) { | ||
| return m * n * sizeof(float); | ||
| } | ||
|
|
||
| void set_data(void* ptr) { data = (float*)ptr; } | ||
|
|
||
| // Convert FP32 output to BF16 and write to destination | ||
| void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) { | ||
| auto [n_start, n_end] = split_range_n((int)n, ith, nth); | ||
| for (int mi = 0; mi < m; mi++) { | ||
| float* src_row = data + mi * n; | ||
| ggml_bf16_t* dst_row = dst + mi * n; | ||
| int j = n_start; | ||
| for (; j + 8 <= n_end; j += 8) { | ||
| __m256 v = _mm256_loadu_ps(src_row + j); | ||
| store_fp32_to_bf16(dst_row + j, v); | ||
| } | ||
| // Scalar tail | ||
| for (; j < n_end; j++) { | ||
| dst_row[j] = GGML_FP32_TO_BF16(src_row[j]); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
| }; | ||
|
|
||
| // ============================================================================ | ||
| // AVX2 BF16 GEMM functions | ||
| // C[m,n] = sum_k A[m,k] * B[n,k] | ||
| // ============================================================================ | ||
|
|
||
| // General GEMM (works for both vec_mul m=1 and mat_mul m>1) | ||
| static inline void gemm_bf16( | ||
| int m, int n, int k, | ||
| GemmKernelAVX2BF16::BufferA& a, | ||
| GemmKernelAVX2BF16::BufferB& b, | ||
| GemmKernelAVX2BF16::BufferC& c, | ||
| int ith, int nth) { | ||
|
|
||
| auto [n_start, n_end] = split_range(n, ith, nth); | ||
|
|
||
| for (int ni = n_start; ni < n_end; ni++) { | ||
| const ggml_bf16_t* b_row = b.b + (size_t)ni * k; | ||
|
|
||
| for (int mi = 0; mi < m; mi++) { | ||
| const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k; | ||
|
|
||
| // AVX2 BF16 dot product (matches ggml_vec_dot_bf16 AVX2 path) | ||
| __m256 c1 = _mm256_setzero_ps(); | ||
| __m256 c2 = _mm256_setzero_ps(); | ||
| __m256 c3 = _mm256_setzero_ps(); | ||
| __m256 c4 = _mm256_setzero_ps(); | ||
|
|
||
| int ki = 0; | ||
| for (; ki + 32 <= k; ki += 32) { | ||
| c1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki), load_bf16_to_fp32(b_row + ki), c1); | ||
| c2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 8), load_bf16_to_fp32(b_row + ki + 8), c2); | ||
| c3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 16), load_bf16_to_fp32(b_row + ki + 16), c3); | ||
| c4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 24), load_bf16_to_fp32(b_row + ki + 24), c4); | ||
| } | ||
|
|
||
| float sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(c1, c3), _mm256_add_ps(c2, c4))); | ||
|
|
||
| // Scalar tail | ||
| for (; ki < k; ki++) { | ||
| sum += GGML_BF16_TO_FP32(a_row[ki]) * GGML_BF16_TO_FP32(b_row[ki]); | ||
| } | ||
|
|
||
| c.data[mi * n + ni] = sum; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // vec_mul: dispatch to gemm_bf16 | ||
| static inline void vec_mul( | ||
| int m, int n, int k, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c, | ||
| int ith, int nth) { | ||
| gemm_bf16(m, n, k, *a, *b, *c, ith, nth); | ||
| } | ||
|
|
||
| // mat_mul: dispatch to gemm_bf16 | ||
| static inline void mat_mul( | ||
| int m, int n, int k, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b, | ||
| std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c, | ||
| int ith, int nth) { | ||
| gemm_bf16(m, n, k, *a, *b, *c, ith, nth); | ||
| } | ||
|
|
||
| } // namespace avx2 | ||
|
|
||
| #endif // CPUINFER_OPERATOR_AVX2_BF16_GEMM_H | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,132 @@ | ||||||||||||||
| /** | ||||||||||||||
| * @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation) | ||||||||||||||
| * @Author : Claude | ||||||||||||||
| * @Date : 2026-03-18 | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||
| * @Version : 1.0.0 | ||||||||||||||
| * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. | ||||||||||||||
| * | ||||||||||||||
| * AVX2 ports of the AVX512 utilities in amx/la/utils.hpp and amx/la/amx.hpp. | ||||||||||||||
| * Uses 256-bit SIMD (8 floats) instead of 512-bit (16 floats). | ||||||||||||||
| **/ | ||||||||||||||
| #ifndef CPUINFER_OPERATOR_AVX2_BF16_UTILS_H | ||||||||||||||
| #define CPUINFER_OPERATOR_AVX2_BF16_UTILS_H | ||||||||||||||
|
|
||||||||||||||
| #include <immintrin.h> | ||||||||||||||
| #include <cmath> | ||||||||||||||
| #include "llama.cpp/ggml.h" | ||||||||||||||
|
|
||||||||||||||
| namespace avx2 { | ||||||||||||||
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // BF16 <-> FP32 conversion | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| // Load 8 BF16 values and convert to 8 FP32 values | ||||||||||||||
| // BF16 is the upper 16 bits of FP32, so shift left by 16 | ||||||||||||||
| static inline __m256 load_bf16_to_fp32(const ggml_bf16_t* src) { | ||||||||||||||
| __m128i bf16 = _mm_loadu_si128((const __m128i*)src); | ||||||||||||||
| __m256i i32 = _mm256_cvtepu16_epi32(bf16); | ||||||||||||||
| return _mm256_castsi256_ps(_mm256_slli_epi32(i32, 16)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Convert 8 FP32 values to 8 BF16 values with round-to-nearest-even | ||||||||||||||
| // Matches ggml_compute_fp32_to_bf16 semantics (ggml-impl.h:87) | ||||||||||||||
| // and amx/la/utils.hpp:24 tie-bit correction | ||||||||||||||
| static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) { | ||||||||||||||
| __m256i i32 = _mm256_castps_si256(src); | ||||||||||||||
| // Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1) | ||||||||||||||
| __m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1)); | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The literal
Suggested change
|
||||||||||||||
| __m256i round = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), tie_bit); | ||||||||||||||
| __m256i rounded = _mm256_add_epi32(i32, round); | ||||||||||||||
| __m256i shifted = _mm256_srli_epi32(rounded, 16); | ||||||||||||||
| // Pack 32-bit -> 16-bit | ||||||||||||||
| // _mm_packus_epi32 processes 128-bit lanes: packs [lo0..lo3, hi0..hi3] -> [lo0..lo3, hi0..hi3] | ||||||||||||||
| __m128i lo = _mm256_castsi256_si128(shifted); | ||||||||||||||
| __m128i hi = _mm256_extracti128_si256(shifted, 1); | ||||||||||||||
| __m128i packed = _mm_packus_epi32(lo, hi); | ||||||||||||||
| _mm_storeu_si128((__m128i*)dst, packed); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Load 16 BF16 -> 2x8 FP32 (corresponds to avx512_32xbf16_to_32xfp32) | ||||||||||||||
| static inline void load_16xbf16_to_2x8xfp32(const ggml_bf16_t* src, __m256* out0, __m256* out1) { | ||||||||||||||
| *out0 = load_bf16_to_fp32(src); | ||||||||||||||
| *out1 = load_bf16_to_fp32(src + 8); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Store 2x8 FP32 -> 16 BF16 (corresponds to avx512_32xfp32_to_32xbf16) | ||||||||||||||
| static inline void store_2x8xfp32_to_16xbf16(__m256* in0, __m256* in1, ggml_bf16_t* dst) { | ||||||||||||||
| store_fp32_to_bf16(dst, *in0); | ||||||||||||||
| store_fp32_to_bf16(dst + 8, *in1); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // Horizontal sum for __m256 (8 floats -> 1 float) | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
// Horizontal sum of 8 floats: reduce 8 lanes -> 4 -> 2 -> 1.
// The reduction tree (low+high halves first, then pairwise) is kept
// identical so results are bit-for-bit stable.
static inline float hsum_avx2(__m256 v) {
    __m128 quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    quad = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    quad = _mm_add_ss(quad, _mm_movehdup_ps(quad));
    return _mm_cvtss_f32(quad);
}
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // Fast exp approximation (AVX2 port of amx::exp_avx512) | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
// Fast exp approximation (AVX2 port of amx::exp_avx512).
// e^x = 2^(x*log2(e)) = 2^i * 2^f with i = round(x*log2(e)) and
// f in [-0.5, 0.5] handled by a degree-5 polynomial.
static inline __m256 exp_avx2(__m256 x) {
    constexpr float kLog2e = 1.44269504089f;
    // IEEE-754 single-precision layout constants (used to build 2^i below).
    constexpr int kFp32ExponentBias = 127;
    constexpr int kFp32MantissaBits = 23;
    constexpr int kFp32MinNormalExp = -126;
    constexpr int kFp32MaxNormalExp = 127;

    const __m256 scaled = _mm256_mul_ps(x, _mm256_set1_ps(kLog2e));
    const __m256i int_part = _mm256_cvtps_epi32(scaled);
    const __m256 frac = _mm256_sub_ps(scaled, _mm256_cvtepi32_ps(int_part));

    // Polynomial for 2^frac; coefficients are ln(2)^k / k! terms.
    const __m256 p1 = _mm256_set1_ps(0.9999999995f);
    const __m256 p2 = _mm256_set1_ps(0.6931471805f);
    const __m256 p3 = _mm256_set1_ps(0.2402265069f);
    const __m256 p4 = _mm256_set1_ps(0.0555041087f);
    const __m256 p5 = _mm256_set1_ps(0.0096181291f);
    const __m256 p6 = _mm256_set1_ps(0.0013333558f);

    // Horner evaluation, innermost term first (same FMA chain as before).
    __m256 frac_exp = _mm256_fmadd_ps(p6, frac, p5);
    frac_exp = _mm256_fmadd_ps(frac_exp, frac, p4);
    frac_exp = _mm256_fmadd_ps(frac_exp, frac, p3);
    frac_exp = _mm256_fmadd_ps(frac_exp, frac, p2);
    frac_exp = _mm256_fmadd_ps(frac_exp, frac, p1);

    // 2^int_part: AVX2 lacks _mm256_scalef_ps, so build the float bit pattern
    // directly: 2^e = reinterpret((e + bias) << mantissa_bits).
    // Clamp e to the normal range [-126, 127] to avoid invalid bit patterns:
    //   e < -126 -> biased < 1   -> denormal/zero (scalef_ps would give 0)
    //   e >  127 -> biased > 254 -> infinity      (scalef_ps would give inf)
    const __m256i clamped = _mm256_max_epi32(
        _mm256_min_epi32(int_part, _mm256_set1_epi32(kFp32MaxNormalExp)),
        _mm256_set1_epi32(kFp32MinNormalExp));
    const __m256i biased = _mm256_add_epi32(clamped, _mm256_set1_epi32(kFp32ExponentBias));
    const __m256 two_pow_i = _mm256_castsi256_ps(_mm256_slli_epi32(biased, kFp32MantissaBits));

    return _mm256_mul_ps(two_pow_i, frac_exp);
}
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // SiLU activation: silu(gate) * up = gate * sigmoid(gate) * up | ||||||||||||||
| // AVX2 port of amx::act_fn | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| static inline __m256 act_fn(__m256 gate_val, __m256 up_val) { | ||||||||||||||
| __m256 neg_gate_val = _mm256_sub_ps(_mm256_setzero_ps(), gate_val); | ||||||||||||||
| // Clamp to avoid exp overflow | ||||||||||||||
| const __m256 max_exp_input = _mm256_set1_ps(88.0f); | ||||||||||||||
| neg_gate_val = _mm256_min_ps(neg_gate_val, max_exp_input); | ||||||||||||||
| __m256 exp_neg_gate = exp_avx2(neg_gate_val); | ||||||||||||||
| __m256 denom = _mm256_add_ps(_mm256_set1_ps(1.0f), exp_neg_gate); | ||||||||||||||
| __m256 act_val = _mm256_div_ps(gate_val, denom); | ||||||||||||||
|
|
||||||||||||||
| return _mm256_mul_ps(act_val, up_val); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| } // namespace avx2 | ||||||||||||||
|
|
||||||||||||||
| #endif // CPUINFER_OPERATOR_AVX2_BF16_UTILS_H | ||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `@Date` in the file header is set to a future date. Please update this to the current date or remove it if it's not meant to be dynamic.