
Commit 596508b

Gitty Burstein (GittyBurstein) and yael-works committed

cleanup: remove unrelated helper function

Co-authored-by: Gitty Burstein <[email protected]>
Co-authored-by: Yael Shuker <[email protected]>
1 parent e241f65 commit 596508b

3 files changed: 0 additions & 191 deletions


ggml/src/ggml-cpu/ops.cpp

Lines changed: 0 additions & 188 deletions
@@ -7930,194 +7930,6 @@ void ggml_compute_forward_argsort(
     }
 }
 
-//------------------------------------------------------------------------------
-// SparseK Attention (CPU, final optimized version)
-//------------------------------------------------------------------------------
-//
-// Implements SparseK Attention as a GGML operator for the CPU backend.
-// Features:
-// • Top-K filtering using nth_element (O(N))
-// • Optional local window (win_local)
-// • Optional global stride (stride_glb)
-// • Numerically stable softmax
-// • Preallocated buffers for performance
-//
-// Author: Yael Shuker & Gitty Burstein
-//------------------------------------------------------------------------------
-
-#include <algorithm>
-#include <vector>
-#include <cmath>
-#include <limits>
-
-static void ggml_compute_forward_sparsek_attn_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    // Single-threaded baseline version
-    if (params->ith != 0) return;
-
-    const struct ggml_tensor * Q = dst->src[0];
-    const struct ggml_tensor * K = dst->src[1];
-    const struct ggml_tensor * V = dst->src[2];
-
-    GGML_ASSERT(Q && K && V);
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F32);
-    GGML_ASSERT(V->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    // Operator parameters
-    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
-    const int32_t win_local  = ggml_get_op_params_i32(dst, 1); // -1 ⇒ no local window
-    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride
-
-    const bool use_local  = (win_local >= 0);
-    const bool use_stride = (stride_glb > 1);
-
-    // GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
-    const int64_t D = Q->ne[0];
-    const int64_t T = Q->ne[1];
-    const int64_t H = Q->ne[2];
-    const int64_t B = Q->ne[3];
-
-    // Dimension validation
-    GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
-    GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
-    GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
-    GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);
-
-    // Parameter sanity checks
-    GGML_ASSERT(k_top >= 0 && k_top <= (int32_t)T);
-    GGML_ASSERT(win_local >= -1);
-    GGML_ASSERT(stride_glb >= 0);
-
-    const float scale = 1.0f / sqrtf((float)D);
-    const float NINF  = -std::numeric_limits<float>::infinity();
-
-    // Preallocated buffers to avoid heap churn
-    std::vector<float>   attn_row((size_t)T, NINF);
-    std::vector<int32_t> cand_idx; cand_idx.reserve((size_t)T);
-    std::vector<float>   scores;   scores.reserve((size_t)T);
-
-    for (int64_t b = 0; b < B; ++b) {
-        for (int64_t h = 0; h < H; ++h) {
-            for (int64_t iq = 0; iq < T; ++iq) {
-
-                // (0) Build candidate index list (always include self)
-                cand_idx.clear();
-                scores.clear();
-
-                if (!use_local && !use_stride) {
-                    // No sparsity: attend to all tokens
-                    for (int64_t j = 0; j < T; ++j)
-                        cand_idx.push_back((int32_t)j);
-                } else {
-                    // Apply local window and/or global stride
-                    for (int64_t j = 0; j < T; ++j) {
-                        const int64_t dist = iq >= j ? iq - j : j - iq;
-                        const bool pass_local  = use_local  && (dist <= (int64_t)win_local);
-                        const bool pass_stride = use_stride && (stride_glb > 0 && j % stride_glb == 0);
-                        if (pass_local || pass_stride || j == iq)
-                            cand_idx.push_back((int32_t)j);
-                    }
-                }
-
-                // Edge case: no candidates or k_top==0 → output zeros
-                if (k_top == 0 || cand_idx.empty()) {
-                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                    std::fill(y0, y0 + D, 0.0f);
-                    continue;
-                }
-
-                // (1) Compute scaled dot-product Q·K only for candidates
-                std::fill(attn_row.begin(), attn_row.end(), NINF);
-                const float * qv = (const float *)((const char *)Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);
-
-                for (int32_t j : cand_idx) {
-                    const float * kv = (const float *)((const char *)K->data + b*K->nb[3] + h*K->nb[2] + (int64_t)j*K->nb[1]);
-                    float dot = 0.0f;
-                    for (int64_t d = 0; d < D; ++d)
-                        dot += qv[d] * kv[d];
-                    attn_row[j] = dot * scale;
-                }
-
-                // (2) Determine true Top-K threshold using nth_element
-                const int num_candidates = (int)cand_idx.size();
-                const int kk = std::min<int>(std::max<int>(1, k_top), num_candidates);
-
-                if (kk < num_candidates) {
-                    scores.resize((size_t)num_candidates);
-                    for (size_t i = 0; i < cand_idx.size(); ++i)
-                        scores[i] = attn_row[cand_idx[i]];
-
-                    std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
-                    const float thr = scores[kk - 1];
-
-                    // Mask all values below the threshold
-                    for (int32_t j : cand_idx)
-                        if (attn_row[j] < thr) attn_row[j] = NINF;
-                }
-
-                // (3) Numerically stable softmax
-                float maxv = NINF;
-                for (int32_t j : cand_idx)
-                    maxv = std::max(maxv, attn_row[j]);
-
-                // Handle all-masked rows
-                if (!std::isfinite(maxv)) {
-                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                    std::fill(y0, y0 + D, 0.0f);
-                    continue;
-                }
-
-                float sum = 0.0f;
-                for (int32_t j : cand_idx) {
-                    if (attn_row[j] == NINF) continue;
-                    const float e = expf(attn_row[j] - maxv);
-                    attn_row[j] = e;
-                    sum += e;
-                }
-
-                const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
-                for (int32_t j : cand_idx) {
-                    if (attn_row[j] == NINF) continue;
-                    attn_row[j] *= inv_sum;
-                }
-
-                // (4) Compute output y = A·V
-                float * y = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                for (int64_t d = 0; d < D; ++d) {
-                    float acc = 0.0f;
-                    for (int32_t j : cand_idx) {
-                        const float aij = attn_row[j];
-                        if (!(aij > 0.0f)) continue; // skip zero or masked
-                        const float * vv = (const float *)((const char *)V->data + b*V->nb[3] + h*V->nb[2] + (int64_t)j*V->nb[1]);
-                        acc += aij * vv[d];
-                    }
-                    y[d] = acc;
-                }
-            }
-        }
-    }
-
-    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
-            k_top, win_local, stride_glb);
-}
-
-void ggml_compute_forward_sparsek_attn(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F32:
-            ggml_compute_forward_sparsek_attn_f32(params, dst);
-            break;
-        default:
-            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
-    }
-}
-
-
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
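
For reference, the operator removed above rests on two standard building blocks: an exact Top-K cutoff found with std::nth_element and a max-subtracted (numerically stable) softmax. The standalone C++ sketch below shows that per-row pattern in isolation; it is not code from the repository, and the helper name topk_softmax_row and the sample values in main are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <limits>
#include <vector>

// Keep only the k_top largest scores in `row`, then normalize them with a
// numerically stable softmax; all other entries become 0.
static void topk_softmax_row(std::vector<float> & row, int k_top) {
    const float NINF = -std::numeric_limits<float>::infinity();
    const int n  = (int)row.size();
    const int kk = std::min(std::max(1, k_top), n);

    // (1) k-th largest value via nth_element (average O(N)), then mask everything below it
    if (kk < n) {
        std::vector<float> scores(row);
        std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
        const float thr = scores[kk - 1];
        for (float & x : row) {
            if (x < thr) x = NINF;
        }
    }

    // (2) stable softmax: subtract the row maximum before exponentiating
    float maxv = NINF;
    for (float x : row) maxv = std::max(maxv, x);
    if (!std::isfinite(maxv)) {                 // every entry was masked
        std::fill(row.begin(), row.end(), 0.0f);
        return;
    }
    float sum = 0.0f;
    for (float & x : row) {
        x = (x == NINF) ? 0.0f : std::exp(x - maxv);
        sum += x;
    }
    const float inv_sum = (sum > 0.0f) ? 1.0f / sum : 0.0f;
    for (float & x : row) x *= inv_sum;
}

int main() {
    std::vector<float> row = {0.1f, 2.0f, -1.0f, 3.0f, 0.5f};
    topk_softmax_row(row, 2);                   // only the two largest scores survive
    for (float x : row) std::printf("%.3f ", x);
    std::printf("\n");
    return 0;
}

Applied one score row at a time, this is the same nth_element threshold followed by exp-normalize sequence that steps (2) and (3) of the deleted function perform for each query position.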

ggml/src/ggml-cpu/ops.h

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params *
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,

tests/test-backend-ops.cpp

Lines changed: 0 additions & 2 deletions
@@ -5528,8 +5528,6 @@ struct test_sparsek_attn : public test_case {
     }
 };
 
-
-
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size

0 commit comments
