Commit 49c7e4b

Implement final optimized SparseK Attention (CPU)

Co-authored-by: Yael <[email protected]>
Co-authored-by: Gitty <[email protected]>

1 parent b19c244 commit 49c7e4b

ggml/src/ggml-cpu/ops.cpp

Lines changed: 110 additions & 42 deletions
@@ -7909,14 +7909,30 @@ void ggml_compute_forward_argsort(
 }
 
 //------------------------------------------------------------------------------
-// SparseK Attention (CPU)
+// SparseK Attention (CPU, final optimized version)
 //------------------------------------------------------------------------------
+//
+// Implements SparseK Attention as a GGML operator for the CPU backend.
+// Features:
+//   • Top-K filtering using nth_element (O(N))
+//   • Optional local window (win_local)
+//   • Optional global stride (stride_glb)
+//   • Numerically stable softmax
+//   • Preallocated buffers for performance
+//
+// Author: Yael Shuker (yael-works)
+//------------------------------------------------------------------------------
+
+#include <algorithm>
+#include <vector>
+#include <cmath>
+#include <limits>
 
 static void ggml_compute_forward_sparsek_attn_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    // Single-threaded baseline version
+    // Single-threaded baseline version
     if (params->ith != 0) return;
 
     const struct ggml_tensor * Q = dst->src[0];
@@ -7929,80 +7945,132 @@ static void ggml_compute_forward_sparsek_attn_f32(
     GGML_ASSERT(V->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
+    // Operator parameters
     const int32_t k_top = ggml_get_op_params_i32(dst, 0);
-    const int32_t win_local = ggml_get_op_params_i32(dst, 1);
-    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
-    GGML_UNUSED(win_local);
-    GGML_UNUSED(stride_glb);
+    const int32_t win_local = ggml_get_op_params_i32(dst, 1);  // -1 ⇒ no local window
+    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride
+
+    const bool use_local  = (win_local >= 0);
+    const bool use_stride = (stride_glb > 1);
 
-    // Tensor dimensions according to GGML layout: ne[0]=d, ne[1]=seq, ne[2]=head, ne[3]=batch
+    // GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
     const int64_t D = Q->ne[0];
     const int64_t T = Q->ne[1];
     const int64_t H = Q->ne[2];
     const int64_t B = Q->ne[3];
 
-    // Temporary buffer for attention scores for one query row
-    std::vector<float> attn_row(T, 0.0f);
+    // Dimension validation
+    GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
+    GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
+    GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
+    GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);
+
+    // Parameter sanity checks
+    GGML_ASSERT(k_top >= 0 && k_top <= (int32_t)T);
+    GGML_ASSERT(win_local >= -1);
+    GGML_ASSERT(stride_glb >= 0);
 
-    const float scale = 1.0f / sqrtf((float) D);
+    const float scale = 1.0f / sqrtf((float)D);
+    const float NINF  = -std::numeric_limits<float>::infinity();
+
+    // Preallocated buffers to avoid heap churn
+    std::vector<float>   attn_row((size_t)T, NINF);
+    std::vector<int32_t> cand_idx; cand_idx.reserve((size_t)T);
+    std::vector<float>   scores;   scores.reserve((size_t)T);
 
-    // Loops over batch, head, and query token
     for (int64_t b = 0; b < B; ++b) {
         for (int64_t h = 0; h < H; ++h) {
             for (int64_t iq = 0; iq < T; ++iq) {
 
-                // (1) Compute dot products Q·K within same (b,h)
-                const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1];
-                const float * qv = (const float *) qbase;
+                // (0) Build candidate index list (always include self)
+                cand_idx.clear();
+                scores.clear();
+
+                if (!use_local && !use_stride) {
+                    // No sparsity: attend to all tokens
+                    for (int64_t j = 0; j < T; ++j)
+                        cand_idx.push_back((int32_t)j);
+                } else {
+                    // Apply local window and/or global stride
+                    for (int64_t j = 0; j < T; ++j) {
+                        const int64_t dist = iq >= j ? iq - j : j - iq;
+                        const bool pass_local  = use_local  && (dist <= (int64_t)win_local);
+                        const bool pass_stride = use_stride && (stride_glb > 0 && j % stride_glb == 0);
+                        if (pass_local || pass_stride || j == iq)
+                            cand_idx.push_back((int32_t)j);
+                    }
+                }
+
+                // Edge case: no candidates or k_top==0 → output zeros
+                if (k_top == 0 || cand_idx.empty()) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
+                }
 
-                for (int64_t j = 0; j < T; ++j) {
-                    const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1];
-                    const float * kv = (const float *) kbase;
+                // (1) Compute scaled dot-product Q·K only for candidates
+                std::fill(attn_row.begin(), attn_row.end(), NINF);
+                const float * qv = (const float *)((const char *)Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);
 
+                for (int32_t j : cand_idx) {
+                    const float * kv = (const float *)((const char *)K->data + b*K->nb[3] + h*K->nb[2] + (int64_t)j*K->nb[1]);
                     float dot = 0.0f;
-                    for (int64_t d = 0; d < D; ++d) {
+                    for (int64_t d = 0; d < D; ++d)
                         dot += qv[d] * kv[d];
-                    }
                     attn_row[j] = dot * scale;
                 }
 
-                // (2) Select top-k threshold using nth_element
-                const int kk = std::max<int>(1, std::min<int>((int)T, k_top));
-                std::vector<float> tmp(attn_row.begin(), attn_row.end());
-                std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
-                const float thr = tmp[kk - 1];
+                // (2) Determine true Top-K threshold using nth_element
+                const int num_candidates = (int)cand_idx.size();
+                const int kk = std::min<int>(std::max<int>(1, k_top), num_candidates);
+
+                if (kk < num_candidates) {
+                    scores.resize((size_t)num_candidates);
+                    for (size_t i = 0; i < cand_idx.size(); ++i)
+                        scores[i] = attn_row[cand_idx[i]];
+
+                    std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
+                    const float thr = scores[kk - 1];
 
-                for (int64_t j = 0; j < T; ++j) {
-                    if (attn_row[j] < thr) attn_row[j] = -INFINITY;
+                    // Mask all values below the threshold
+                    for (int32_t j : cand_idx)
+                        if (attn_row[j] < thr) attn_row[j] = NINF;
                 }
 
-                // (3) Numerically stable softmax on the masked row
-                float maxv = -INFINITY;
-                for (int64_t j = 0; j < T; ++j) {
+                // (3) Numerically stable softmax
+                float maxv = NINF;
+                for (int32_t j : cand_idx)
                     maxv = std::max(maxv, attn_row[j]);
+
+                // Handle all-masked rows
+                if (!std::isfinite(maxv)) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
                 }
+
                 float sum = 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
-                    float v = attn_row[j] - maxv;
-                    float e = expf(v);
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
+                    const float e = expf(attn_row[j] - maxv);
                     attn_row[j] = e;
                     sum += e;
                 }
-                const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
+
+                const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
                     attn_row[j] *= inv_sum;
                 }
 
-                // (4) Compute output = A·V (weighted sum)
-                float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-
+                // (4) Compute output y = A·V
+                float * y = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
                 for (int64_t d = 0; d < D; ++d) {
                     float acc = 0.0f;
-                    for (int64_t j = 0; j < T; ++j) {
+                    for (int32_t j : cand_idx) {
                         const float aij = attn_row[j];
-                        if (aij == 0.0f) continue; // skip masked entries
-                        const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1];
-                        const float * vv = (const float *) vbase;
+                        if (!(aij > 0.0f)) continue; // skip zero or masked
+                        const float * vv = (const float *)((const char *)V->data + b*V->nb[3] + h*V->nb[2] + (int64_t)j*V->nb[1]);
                         acc += aij * vv[d];
                     }
                     y[d] = acc;
@@ -8012,7 +8080,7 @@ static void ggml_compute_forward_sparsek_attn_f32(
     }
 
     GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
-        k_top, win_local, stride_glb);
+                     k_top, win_local, stride_glb);
 }
 
 void ggml_compute_forward_sparsek_attn(

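Steps (2) and (3) then keep only the kk largest of those candidate scores, with kk = min(max(1, k_top), number of candidates), and normalize them with a max-shifted softmax, so entries below the k-th largest score contribute exactly zero weight. A self-contained sketch of that filtering, again illustrative rather than the commit's code (sparsek_softmax_topk is a made-up name):

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

// Illustrative sketch of steps (2)-(3): top-k threshold via nth_element,
// then a numerically stable softmax over the surviving scores.
static std::vector<float> sparsek_softmax_topk(std::vector<float> scores, int k_top) {
    const float NINF = -std::numeric_limits<float>::infinity();
    const int n  = (int)scores.size();
    const int kk = std::min(std::max(k_top, 1), n);

    if (kk < n) {
        std::vector<float> tmp = scores;          // partial selection on a copy
        std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
        const float thr = tmp[kk - 1];            // value of the k-th largest score
        for (float & s : scores) {
            if (s < thr) s = NINF;                // mask everything below the threshold
        }
    }

    float maxv = NINF;
    for (float s : scores) maxv = std::max(maxv, s);

    float sum = 0.0f;
    for (float & s : scores) {
        s = (s == NINF) ? 0.0f : std::exp(s - maxv);  // shift by the row max for stability
        sum += s;
    }
    const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
    for (float & s : scores) s *= inv_sum;        // masked entries stay at exactly 0
    return scores;
}
```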