22 changes: 21 additions & 1 deletion ggml/include/ggml.h
@@ -529,7 +529,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

GGML_OP_SPARSEK_ATTN,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
@@ -2231,6 +2231,26 @@ extern "C" {
// n_head % ne32 == 0
// ne3 % ne33 == 0
//

GGML_API struct ggml_tensor * ggml_sparsek_attn(
struct ggml_context * ctx,
struct ggml_tensor * Q,
struct ggml_tensor * K,
struct ggml_tensor * V,
int32_t k_top,
int32_t win_local,
int32_t stride_global);

GGML_API void ggml_sparsek_attn_set_params(
struct ggml_tensor * a,
int32_t k_top,
int32_t win_local,
int32_t stride_global);

GGML_API int32_t ggml_sparsek_attn_get_param(
const struct ggml_tensor * a,
int index);

GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
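
A minimal usage sketch for the new API (illustrative only, not part of this diff; the tensor sizes, memory budget, and the ggml-cpu.h include for ggml_graph_compute_with_ctx are assumptions about the surrounding tree):

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    // Q/K/V share the (D, T, H, B) layout that the CPU kernel asserts.
    const int64_t D = 64, T = 128, H = 8, B = 1;

    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
    struct ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
    struct ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
    // ... fill Q, K and V with F32 data here ...

    // keep the 16 strongest keys per query, a +/-32 token local window,
    // and every 64th token as a global candidate
    struct ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, 16, 32, 64);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    ggml_free(ctx);
    return 0;
}
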
4 changes: 4 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1952,6 +1952,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_SPARSEK_ATTN:
{
ggml_compute_forward_sparsek_attn(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
int32_t t = ggml_get_op_params_i32(tensor, 0);
189 changes: 189 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -9,6 +9,7 @@

#include <float.h>
#include <algorithm>
#include <vector>

// ggml_compute_forward_dup

@@ -7907,6 +7908,194 @@ void ggml_compute_forward_argsort(
}
}

//------------------------------------------------------------------------------
// SparseK Attention (CPU backend, single-threaded reference implementation)
//------------------------------------------------------------------------------
//
// Implements SparseK Attention as a GGML operator for the CPU backend.
// Features:
// • Top-K filtering using std::nth_element (expected O(N) selection)
// • Optional local window (win_local)
// • Optional global stride (stride_glb)
// • Numerically stable softmax
// • Preallocated buffers for performance
//
// Author: Yael Shuker & Gitty Burstein
//------------------------------------------------------------------------------
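//
// Candidate selection example (illustrative, not part of the kernel): with
// win_local = 2 and stride_global = 4, query position iq = 10 in a sequence
// of T = 16 tokens keeps
//   local window  : {8, 9, 10, 11, 12}   (|iq - j| <= 2)
//   global stride : {0, 4, 8, 12}        (j % 4 == 0)
// so the candidate set is {0, 4, 8, 9, 10, 11, 12}; the top-k_top threshold
// then prunes this set further before the softmax.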

#include <cmath>
#include <functional> // std::greater
#include <limits>

static void ggml_compute_forward_sparsek_attn_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

// Single-threaded baseline version
if (params->ith != 0) return;

const struct ggml_tensor * Q = dst->src[0];
const struct ggml_tensor * K = dst->src[1];
const struct ggml_tensor * V = dst->src[2];

GGML_ASSERT(Q && K && V);
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F32);
GGML_ASSERT(V->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

// Operator parameters
const int32_t k_top = ggml_get_op_params_i32(dst, 0);
const int32_t win_local = ggml_get_op_params_i32(dst, 1); // -1 ⇒ no local window
const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride

const bool use_local = (win_local >= 0);
const bool use_stride = (stride_glb > 1);

// GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
const int64_t D = Q->ne[0];
const int64_t T = Q->ne[1];
const int64_t H = Q->ne[2];
const int64_t B = Q->ne[3];

// Dimension validation
GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);

// Parameter sanity checks
GGML_ASSERT(k_top >= 0 && k_top <= (int32_t)T);
GGML_ASSERT(win_local >= -1);
GGML_ASSERT(stride_glb >= 0);

const float scale = 1.0f / sqrtf((float)D);
const float NINF = -std::numeric_limits<float>::infinity();

// Preallocated buffers to avoid heap churn
std::vector<float> attn_row((size_t)T, NINF);
std::vector<int32_t> cand_idx; cand_idx.reserve((size_t)T);
std::vector<float> scores; scores.reserve((size_t)T);

for (int64_t b = 0; b < B; ++b) {
for (int64_t h = 0; h < H; ++h) {
for (int64_t iq = 0; iq < T; ++iq) {

// (0) Build candidate index list (always include self)
cand_idx.clear();
scores.clear();

if (!use_local && !use_stride) {
// No sparsity: attend to all tokens
for (int64_t j = 0; j < T; ++j)
cand_idx.push_back((int32_t)j);
} else {
// Apply local window and/or global stride
for (int64_t j = 0; j < T; ++j) {
const int64_t dist = iq >= j ? iq - j : j - iq;
const bool pass_local = use_local && (dist <= (int64_t)win_local);
const bool pass_stride = use_stride && (stride_glb > 0 && j % stride_glb == 0);
if (pass_local || pass_stride || j == iq)
cand_idx.push_back((int32_t)j);
}
}

// Edge case: no candidates or k_top==0 → output zeros
if (k_top == 0 || cand_idx.empty()) {
float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
std::fill(y0, y0 + D, 0.0f);
continue;
}

// (1) Compute scaled dot-product Q·K only for candidates
std::fill(attn_row.begin(), attn_row.end(), NINF);
const float * qv = (const float *)((const char *)Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);

for (int32_t j : cand_idx) {
const float * kv = (const float *)((const char *)K->data + b*K->nb[3] + h*K->nb[2] + (int64_t)j*K->nb[1]);
float dot = 0.0f;
for (int64_t d = 0; d < D; ++d)
dot += qv[d] * kv[d];
attn_row[j] = dot * scale;
}

// (2) Determine true Top-K threshold using nth_element
const int num_candidates = (int)cand_idx.size();
const int kk = std::min<int>(std::max<int>(1, k_top), num_candidates);

if (kk < num_candidates) {
scores.resize((size_t)num_candidates);
for (size_t i = 0; i < cand_idx.size(); ++i)
scores[i] = attn_row[cand_idx[i]];

std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
const float thr = scores[kk - 1];
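// e.g. (illustrative) scores = {0.9, 0.2, 0.7, 0.5} with kk = 2 yields
// thr = 0.7, so only the two highest-scoring candidates survive the mask below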

// Mask all values below the threshold
for (int32_t j : cand_idx)
if (attn_row[j] < thr) attn_row[j] = NINF;
}

// (3) Numerically stable softmax
float maxv = NINF;
for (int32_t j : cand_idx)
maxv = std::max(maxv, attn_row[j]);

// Handle all-masked rows
if (!std::isfinite(maxv)) {
float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
std::fill(y0, y0 + D, 0.0f);
continue;
}

float sum = 0.0f;
for (int32_t j : cand_idx) {
if (attn_row[j] == NINF) continue;
const float e = expf(attn_row[j] - maxv);
attn_row[j] = e;
sum += e;
}

const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
for (int32_t j : cand_idx) {
if (attn_row[j] == NINF) continue;
attn_row[j] *= inv_sum;
}

// (4) Compute output y = A·V
float * y = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
for (int64_t d = 0; d < D; ++d) {
float acc = 0.0f;
for (int32_t j : cand_idx) {
const float aij = attn_row[j];
if (!(aij > 0.0f)) continue; // skip zero or masked
const float * vv = (const float *)((const char *)V->data + b*V->nb[3] + h*V->nb[2] + (int64_t)j*V->nb[1]);
acc += aij * vv[d];
}
y[d] = acc;
}
}
}
}
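// Cost note: each query does O(C*D) work for its C candidate keys (plus an
// O(C) selection), so the kernel is roughly O(B*H*T*C*D); dense attention
// corresponds to the C = T case.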

GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
k_top, win_local, stride_glb);
}

void ggml_compute_forward_sparsek_attn(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (dst->type) {
case GGML_TYPE_F32:
ggml_compute_forward_sparsek_attn_f32(params, dst);
break;
default:
GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
}
}


// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16(
2 changes: 2 additions & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -86,6 +86,8 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params *
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);

void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
54 changes: 50 additions & 4 deletions ggml/src/ggml.c
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",

"SPARSEK_ATTN",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
"SSM_CONV",
@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",

"sparsek_attn(x)",
"flash_attn_ext(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5063,6 +5063,52 @@ struct ggml_tensor * ggml_top_k(
return result;
}

// ggml_sparsek_attn
struct ggml_tensor * ggml_sparsek_attn(
struct ggml_context * ctx,
struct ggml_tensor * Q,
struct ggml_tensor * K,
struct ggml_tensor * V,
int32_t k_top,
int32_t win_local,
int32_t stride_global) {

GGML_ASSERT(ggml_can_mul_mat(K, Q));
GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);

// dst uses the same (D, T, H, B) layout as Q, matching the CPU kernel's indexing
int64_t ne[4] = { V->ne[0], Q->ne[1], Q->ne[2], Q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);


int32_t params_i32[3] = { k_top, win_local, stride_global };
ggml_set_op_params(result, params_i32, sizeof(params_i32));

result->op = GGML_OP_SPARSEK_ATTN;
result->src[0] = Q;
result->src[1] = K;
result->src[2] = V;

return result;
}


void ggml_sparsek_attn_set_params(struct ggml_tensor * a,
int32_t k_top,
int32_t win_local,
int32_t stride_global) {
GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
ggml_set_op_params_i32(a, 0, k_top);
ggml_set_op_params_i32(a, 1, win_local);
ggml_set_op_params_i32(a, 2, stride_global);
}

int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
GGML_ASSERT(index >= 0 && index < 3);
return ggml_get_op_params_i32(a, index);
}



// ggml_flash_attn_ext

struct ggml_tensor * ggml_flash_attn_ext(