diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index d948b00cc7f30..ad24f341bdd5d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -529,7 +529,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
-
+        GGML_OP_SPARSEK_ATTN,
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -2231,6 +2231,26 @@ extern "C" {
     // n_head % ne32 == 0
     // ne3 % ne33 == 0
     //
+
+    GGML_API struct ggml_tensor * ggml_sparsek_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * Q,
+            struct ggml_tensor  * K,
+            struct ggml_tensor  * V,
+            int32_t k_top,
+            int32_t win_local,
+            int32_t stride_global);
+
+    GGML_API void ggml_sparsek_attn_set_params(
+            struct ggml_tensor * a,
+            int32_t k_top,
+            int32_t win_local,
+            int32_t stride_global);
+
+    GGML_API int32_t ggml_sparsek_attn_get_param(
+            const struct ggml_tensor * a,
+            int index);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
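As a quick orientation for reviewers, here is a minimal sketch of how the new public API could be exercised end to end on the CPU backend. The context size, tensor shapes, and the k_top/win_local/stride_global values below are illustrative assumptions, not part of this patch:

    // hypothetical driver for ggml_sparsek_attn (all values are examples only)
    #include "ggml.h"
    #include "ggml-cpu.h"

    void run_sparsek_example(void) {
        struct ggml_init_params ip = { /*mem_size*/ 256*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // Q/K/V use the (D, T, H, B) layout expected by the operator
        const int64_t D = 64, T = 128, H = 4, B = 1;
        struct ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
        struct ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
        struct ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
        // ... fill Q/K/V with data ...

        // k_top=16, win_local=32, stride_global=64: arbitrary example values
        struct ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, 16, 32, 64);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

        ggml_free(ctx);
    }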
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 9ec485cfa2ff7..3fa954e1c324a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1952,6 +1952,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
+        case GGML_OP_SPARSEK_ATTN:
+            {
+                ggml_compute_forward_sparsek_attn(params, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 3156bd60101d7..762d340761d03 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9,6 +9,7 @@
 #include <cmath>
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
@@ -7907,6 +7908,193 @@ void ggml_compute_forward_argsort(
     }
 }
 
+//------------------------------------------------------------------------------
+// SparseK Attention (CPU reference implementation)
+//------------------------------------------------------------------------------
+//
+// Implements SparseK Attention as a GGML operator for the CPU backend.
+// Features:
+//   • Top-K filtering using nth_element (average O(N))
+//   • Optional local window (win_local)
+//   • Optional global stride (stride_global)
+//   • Numerically stable softmax
+//   • Preallocated buffers for performance
+//
+// Authors: Yael Shuker & Gitty Burstein
+//------------------------------------------------------------------------------
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+static void ggml_compute_forward_sparsek_attn_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    // Single-threaded baseline version
+    if (params->ith != 0) return;
+
+    const struct ggml_tensor * Q = dst->src[0];
+    const struct ggml_tensor * K = dst->src[1];
+    const struct ggml_tensor * V = dst->src[2];
+
+    GGML_ASSERT(Q && K && V);
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F32);
+    GGML_ASSERT(V->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // Operator parameters
+    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
+    const int32_t win_local  = ggml_get_op_params_i32(dst, 1); // -1 ⇒ no local window
+    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride
+
+    const bool use_local  = (win_local >= 0);
+    const bool use_stride = (stride_glb > 1);
+
+    // GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
+    const int64_t D = Q->ne[0];
+    const int64_t T = Q->ne[1];
+    const int64_t H = Q->ne[2];
+    const int64_t B = Q->ne[3];
+
+    // Dimension validation
+    GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
+    GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
+    GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
+    GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);
+
+    // Parameter sanity checks
+    GGML_ASSERT(k_top >= 0 && k_top <= (int32_t)T);
+    GGML_ASSERT(win_local >= -1);
+    GGML_ASSERT(stride_glb >= 0);
+
+    const float scale = 1.0f / sqrtf((float)D);
+    const float NINF  = -std::numeric_limits<float>::infinity();
+
+    // Preallocated buffers to avoid heap churn
+    std::vector<float>   attn_row((size_t)T, NINF);
+    std::vector<int32_t> cand_idx; cand_idx.reserve((size_t)T);
+    std::vector<float>   scores;   scores.reserve((size_t)T);
+
+    for (int64_t b = 0; b < B; ++b) {
+        for (int64_t h = 0; h < H; ++h) {
+            for (int64_t iq = 0; iq < T; ++iq) {
+
+                // (0) Build candidate index list (always include self)
+                cand_idx.clear();
+                scores.clear();
+
+                if (!use_local && !use_stride) {
+                    // No sparsity: attend to all tokens
+                    for (int64_t j = 0; j < T; ++j)
+                        cand_idx.push_back((int32_t)j);
+                } else {
+                    // Apply local window and/or global stride
+                    for (int64_t j = 0; j < T; ++j) {
+                        const int64_t dist = iq >= j ? iq - j : j - iq;
+                        const bool pass_local  = use_local  && (dist <= (int64_t)win_local);
+                        const bool pass_stride = use_stride && (j % stride_glb == 0);
+                        if (pass_local || pass_stride || j == iq)
+                            cand_idx.push_back((int32_t)j);
+                    }
+                }
+
+                // Edge case: no candidates or k_top == 0 → output zeros
+                if (k_top == 0 || cand_idx.empty()) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
+                }
+
+                // (1) Compute scaled dot-product Q·K only for candidates
+                std::fill(attn_row.begin(), attn_row.end(), NINF);
+                const float * qv = (const float *)((const char *)Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);
+
+                for (int32_t j : cand_idx) {
+                    const float * kv = (const float *)((const char *)K->data + b*K->nb[3] + h*K->nb[2] + (int64_t)j*K->nb[1]);
+                    float dot = 0.0f;
+                    for (int64_t d = 0; d < D; ++d)
+                        dot += qv[d] * kv[d];
+                    attn_row[j] = dot * scale;
+                }
+
+                // (2) Determine the true Top-K threshold using nth_element
+                const int num_candidates = (int)cand_idx.size();
+                const int kk = std::min(std::max(1, k_top), num_candidates);
+
+                if (kk < num_candidates) {
+                    scores.resize((size_t)num_candidates);
+                    for (size_t i = 0; i < cand_idx.size(); ++i)
+                        scores[i] = attn_row[cand_idx[i]];
+
+                    std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
+                    const float thr = scores[kk - 1];
+
+                    // Mask all values below the threshold
+                    for (int32_t j : cand_idx)
+                        if (attn_row[j] < thr) attn_row[j] = NINF;
+                }
+
+                // (3) Numerically stable softmax
+                float maxv = NINF;
+                for (int32_t j : cand_idx)
+                    maxv = std::max(maxv, attn_row[j]);
+
+                // Handle all-masked rows
+                if (!std::isfinite(maxv)) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
+                }
+
+                float sum = 0.0f;
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
+                    const float e = expf(attn_row[j] - maxv);
+                    attn_row[j] = e;
+                    sum += e;
+                }
+
+                const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
+                    attn_row[j] *= inv_sum;
+                }
+
+                // (4) Compute output y = A·V
+                float * y = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                for (int64_t d = 0; d < D; ++d) {
+                    float acc = 0.0f;
+                    for (int32_t j : cand_idx) {
+                        const float aij = attn_row[j];
+                        if (!(aij > 0.0f)) continue; // skip zero or masked entries
+                        const float * vv = (const float *)((const char *)V->data + b*V->nb[3] + h*V->nb[2] + (int64_t)j*V->nb[1]);
+                        acc += aij * vv[d];
+                    }
+                    y[d] = acc;
+                }
+            }
+        }
+    }
+
+    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
+                     k_top, win_local, stride_glb);
+}
+
+void ggml_compute_forward_sparsek_attn(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_sparsek_attn_f32(params, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16(
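The Top-K selection above relies on the post-condition of std::nth_element: after a partial sort with std::greater, the element at position kk-1 is exactly the kk-th largest value, and everything after it compares less-or-equal, all in average linear time. A self-contained illustration (toy values, not from the patch):

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    int main() {
        std::vector<float> scores = {0.1f, 2.0f, -1.0f, 0.7f, 1.5f};
        const int kk = 2; // keep the two largest scores
        std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
        const float thr = scores[kk - 1]; // 1.5f: the kk-th largest value
        std::printf("threshold = %.1f\n", thr);
        // in the kernel, any candidate with attn_row[j] < thr is then masked to -inf
        return 0;
    }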
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 9824a03b45833..e43b23a5587bd 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -86,6 +86,8 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9be35c1be8456..9ad055c994672 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
-
+    "SPARSEK_ATTN",
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
-
+    "sparsek_attn(x)",
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5063,6 +5063,51 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
+// ggml_sparsek_attn
+
+struct ggml_tensor * ggml_sparsek_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * Q,
+        struct ggml_tensor  * K,
+        struct ggml_tensor  * V,
+        int32_t k_top,
+        int32_t win_local,
+        int32_t stride_global) {
+
+    GGML_ASSERT(ggml_can_mul_mat(K, Q));
+    GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);
+
+    // dst keeps the (D, T, H, B) layout of Q, with D taken from V,
+    // matching the indexing convention of the CPU kernel
+    int64_t ne[4] = { V->ne[0], Q->ne[1], Q->ne[2], Q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params_i32[3] = { k_top, win_local, stride_global };
+    ggml_set_op_params(result, params_i32, sizeof(params_i32));
+
+    result->op     = GGML_OP_SPARSEK_ATTN;
+    result->src[0] = Q;
+    result->src[1] = K;
+    result->src[2] = V;
+
+    return result;
+}
+
+void ggml_sparsek_attn_set_params(
+        struct ggml_tensor * a,
+        int32_t k_top,
+        int32_t win_local,
+        int32_t stride_global) {
+    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
+    ggml_set_op_params_i32(a, 0, k_top);
+    ggml_set_op_params_i32(a, 1, win_local);
+    ggml_set_op_params_i32(a, 2, stride_global);
+}
+
+int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
+    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
+    return ggml_get_op_params_i32(a, index);
+}
+
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
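Because the three hyper-parameters live in op_params slots 0 through 2, they can be re-tuned on an existing graph node between evaluations instead of rebuilding the graph. A brief sketch; the `node` variable is assumed to come from an earlier ggml_sparsek_attn call:

    // tighten the sparsity pattern before re-running the same graph
    ggml_sparsek_attn_set_params(node, /*k_top*/ 8, /*win_local*/ 16, /*stride_global*/ 64);
    GGML_ASSERT(ggml_sparsek_attn_get_param(node, 0) == 8); // slot 0 = k_top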
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index aee1730137900..5350ea13e6ee6 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5362,7 +5362,44 @@ struct test_leaky_relu : public test_case {
     }
 };
 
-// GGML_OP_FLASH_ATTN_EXT
+// GGML_OP_SPARSEK_ATTN
+struct test_sparsek_attn : public test_case {
+    const int64_t d_qk;
+    const int64_t d_v;
+    const int64_t n_head;
+    const int64_t n_tokens;
+    const int64_t batch;
+    const int32_t k_top;
+    const int32_t win_local;
+    const int32_t stride_global;
+
+    std::string vars() override {
+        return VARS_TO_STR8(d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global);
+    }
+
+    test_sparsek_attn(int64_t d_qk = 128, int64_t d_v = 128, int64_t n_head = 8,
+                      int64_t n_tokens = 256, int64_t batch = 4,
+                      int32_t k_top = 32, int32_t win_local = 64, int32_t stride_global = 128)
+        : d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch),
+          k_top(k_top), win_local(win_local), stride_global(stride_global) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_q = n_tokens;
+        ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_q,      n_head, batch);
+        ggml_set_name(Q, "Q");
+        ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch);
+        ggml_set_name(K, "K");
+        ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v,  n_tokens, n_head, batch);
+        ggml_set_name(V, "V");
+
+        ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global);
+        ggml_set_name(out, "SPARSEK_ATTN_out");
+
+        return out;
+    }
+};
+
+// GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size
     const int64_t hsv; // V head size
@@ -7134,6 +7171,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             }
         }
     }
+    // ---- SPARSEK_ATTN --------------------------------------------------
+    for (int64_t d_qk : {64, 128}) {
+        for (int64_t d_v : {64, 128}) {
+            for (int64_t n_head : {4, 8}) {
+                for (int64_t kv : {113, 512}) {
+                    for (int64_t b : {1, 4}) {
+                        for (int32_t k_top : {16, 32}) {
+                            for (int32_t win_local : {32, 64}) {
+                                test_cases.emplace_back(new test_sparsek_attn(
+                                    d_qk, d_v, n_head, kv, b, k_top, win_local, /*stride_global*/ 128));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
 
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10,  5,  4,  3}));
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000,  1,  1,  1}));
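Assuming the existing test-backend-ops conventions hold (its -o flag filters test cases by the output op's name), the new cases can be run in isolation from the build directory:

    # runs only the SPARSEK_ATTN cases, comparing each backend against the CPU reference
    ./bin/test-backend-ops test -o SPARSEK_ATTN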