Commit 39a117f

Gitty Burstein committed
fix SparseK CPU operator implementation
Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>
1 parent a5daf2f commit 39a117f

5 files changed: +81 -184 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 1 addition & 2 deletions
@@ -1955,8 +1955,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         case GGML_OP_SPARSEK_ATTN:
             {
                 ggml_compute_forward_sparsek_attn(params, tensor);
-                break;
-            }
+            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
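The "} break;" form matches the case-block style used elsewhere in ggml_compute_forward, where each case wraps its body in a block and places break on the closing brace. A minimal standalone illustration of that pattern, with hypothetical op names rather than real ggml ones:

#include <cstdio>

enum op_t { OP_A, OP_B };

// Dispatch in the ggml_compute_forward style: the break sits on the
// closing brace of each case block.
static void dispatch(op_t op) {
    switch (op) {
        case OP_A:
            {
                printf("op A\n");
            } break;
        case OP_B:
            {
                printf("op B\n");
            } break;
    }
}

int main(void) {
    dispatch(OP_A);
    return 0;
}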

ggml/src/ggml-cpu/ops.cpp

Lines changed: 80 additions & 41 deletions
@@ -9,6 +9,7 @@
 
 #include <float.h>
 #include <algorithm>
+#include <vector>
 
 // ggml_compute_forward_dup
 

@@ -7915,7 +7916,8 @@ static void ggml_compute_forward_sparsek_attn_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    if (params->ith != 0) return; // main thread only
+    // Single-threaded baseline version (expand later for parallelism)
+    if (params->ith != 0) return;
 
     const struct ggml_tensor * Q = dst->src[0];
     const struct ggml_tensor * K = dst->src[1];
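The replacement comment marks this kernel as a single-threaded baseline. For reference, a standalone sketch of how the work could later be split across threads, assuming the convention used by other ggml CPU ops where params->ith is the thread index and params->nth the thread count (the names B, H, T and the row decode are illustrative, not from the commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main(void) {
    const int     nth = 4;               // assumed thread count (params->nth)
    const int64_t B = 2, H = 8, T = 128; // batch, heads, sequence length
    const int64_t nr = B*H*T;            // one unit of work per query row

    for (int ith = 0; ith < nth; ++ith) { // ith plays the role of params->ith
        const int64_t dr  = (nr + nth - 1)/nth;     // ceil(nr/nth) rows per thread
        const int64_t ir0 = dr*ith;                 // first row for this thread
        const int64_t ir1 = std::min(ir0 + dr, nr); // one past the last row
        // row ir decodes back to (b, h, iq) = (ir/(H*T), (ir/T)%H, ir%T)
        printf("thread %d: rows [%lld, %lld)\n", ith, (long long) ir0, (long long) ir1);
    }
    return 0;
}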
@@ -7925,56 +7927,87 @@ static void ggml_compute_forward_sparsek_attn_f32(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F32);
     GGML_ASSERT(V->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int32_t k_top = ggml_get_op_params_i32(dst, 0);
     const int32_t win_local = ggml_get_op_params_i32(dst, 1);
     const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
+    GGML_UNUSED(win_local);
+    GGML_UNUSED(stride_glb);
 
-    const int64_t D = Q->ne[0]; // embedding dim
-    const int64_t T = Q->ne[1]; // sequence length
+    // Tensor dimensions according to GGML layout: ne[0]=d, ne[1]=seq, ne[2]=head, ne[3]=batch
+    const int64_t D = Q->ne[0];
+    const int64_t T = Q->ne[1];
+    const int64_t H = Q->ne[2];
+    const int64_t B = Q->ne[3];
 
-    const float * q = (const float *) Q->data;
-    const float * k = (const float *) K->data;
-    const float * v = (const float *) V->data;
-    float * out = (float *) dst->data;
+    // Temporary buffer for attention scores for one query row
+    std::vector<float> attn_row(T, 0.0f);
 
-
-    for (int64_t i = 0; i < T; ++i) {
-        for (int64_t j = 0; j < T; ++j) {
-            float dot = 0.0f;
-            for (int64_t d = 0; d < D; ++d)
-                dot += q[i*D + d] * k[j*D + d];
-            out[i*T + j] = dot / sqrtf((float) D);
-        }
-    }
+    const float scale = 1.0f / sqrtf((float) D);
 
-    for (int64_t i = 0; i < T; ++i) {
-        float * row = &out[i*T];
-        for (int64_t j = 0; j < T; ++j)
-            if (row[j] < row[k_top]) row[j] = -INFINITY;
-    }
+    // Loops over batch, head, and query token
+    for (int64_t b = 0; b < B; ++b) {
+        for (int64_t h = 0; h < H; ++h) {
+            for (int64_t iq = 0; iq < T; ++iq) {
 
-    for (int64_t i = 0; i < T; ++i) {
-        float maxv = -INFINITY;
-        for (int64_t j = 0; j < T; ++j)
-            if (out[i*T + j] > maxv) maxv = out[i*T + j];
-        float sum = 0.0f;
-        for (int64_t j = 0; j < T; ++j) {
-            out[i*T + j] = expf(out[i*T + j] - maxv);
-            sum += out[i*T + j];
-        }
-        for (int64_t j = 0; j < T; ++j)
-            out[i*T + j] /= sum;
-    }
+                // (1) Compute dot products Q·K within same (b,h)
+                const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1];
+                const float * qv = (const float *) qbase;
 
+                for (int64_t j = 0; j < T; ++j) {
+                    const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1];
+                    const float * kv = (const float *) kbase;
 
-    float * result = (float *) dst->data;
-    for (int64_t i = 0; i < T; ++i) {
-        for (int64_t d = 0; d < D; ++d) {
-            float sum = 0.0f;
-            for (int64_t j = 0; j < T; ++j)
-                sum += out[i*T + j] * v[j*D + d];
-            result[i*D + d] = sum;
+                    float dot = 0.0f;
+                    for (int64_t d = 0; d < D; ++d) {
+                        dot += qv[d] * kv[d];
+                    }
+                    attn_row[j] = dot * scale;
+                }
+
+                // (2) Select top-k threshold using nth_element
+                const int kk = std::max<int>(1, std::min<int>((int) T, k_top));
+                std::vector<float> tmp(attn_row.begin(), attn_row.end());
+                std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
+                const float thr = tmp[kk - 1];
+
+                for (int64_t j = 0; j < T; ++j) {
+                    if (attn_row[j] < thr) attn_row[j] = -INFINITY;
+                }
+
+                // (3) Numerically stable softmax on the masked row
+                float maxv = -INFINITY;
+                for (int64_t j = 0; j < T; ++j) {
+                    maxv = std::max(maxv, attn_row[j]);
+                }
+                float sum = 0.0f;
+                for (int64_t j = 0; j < T; ++j) {
+                    float v = attn_row[j] - maxv;
+                    float e = expf(v);
+                    attn_row[j] = e;
+                    sum += e;
+                }
+                const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
+                for (int64_t j = 0; j < T; ++j) {
+                    attn_row[j] *= inv_sum;
+                }
+
+                // (4) Compute output = A·V (weighted sum)
+                float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+
+                for (int64_t d = 0; d < D; ++d) {
+                    float acc = 0.0f;
+                    for (int64_t j = 0; j < T; ++j) {
+                        const float aij = attn_row[j];
+                        if (aij == 0.0f) continue; // skip masked entries
+                        const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1];
+                        const float * vv = (const float *) vbase;
+                        acc += aij * vv[d];
+                    }
+                    y[d] = acc;
+                }
+            }
         }
     }
 
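Steps (2) and (3) are the core of the SparseK masking: keep only the k_top largest scores in a row, then take a numerically stable softmax over what survives. A standalone toy version of just those two steps (the input row and k_top value are made up for illustration):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

int main(void) {
    std::vector<float> row = {0.9f, -0.2f, 1.4f, 0.1f, 0.7f}; // toy Q·K scores
    const int k_top = 3;

    // (2) threshold = k-th largest score, found with nth_element
    const int kk = std::max(1, std::min((int) row.size(), k_top));
    std::vector<float> tmp(row);
    std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
    const float thr = tmp[kk - 1];
    for (float & s : row) if (s < thr) s = -INFINITY; // mask sub-threshold scores

    // (3) stable softmax: subtract the max first; exp(-inf) underflows cleanly to 0
    float maxv = -INFINITY;
    for (float s : row) maxv = std::max(maxv, s);
    float sum = 0.0f;
    for (float & s : row) { s = std::exp(s - maxv); sum += s; }
    for (float & s : row) s /= sum;

    for (float s : row) printf("%.4f ", s); // masked entries print as 0.0000
    printf("\n");
    return 0;
}

Using nth_element keeps the per-row selection at O(T) on average, instead of the O(T log T) a full sort would cost.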
@@ -7985,7 +8018,13 @@ static void ggml_compute_forward_sparsek_attn_f32(
 void ggml_compute_forward_sparsek_attn(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    ggml_compute_forward_sparsek_attn_f32(params, dst);
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_sparsek_attn_f32(params, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
+    }
 }
 
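Putting the pieces together: for each query position i within one (batch, head) pair, the rewritten kernel computes, with D the head dimension, T the sequence length, and \tau_i the k_top-th largest score of row i:

a_{ij} = \frac{q_i \cdot k_j}{\sqrt{D}}, \qquad
\tilde{a}_{ij} =
\begin{cases}
a_{ij} & \text{if } a_{ij} \ge \tau_i \\
-\infty & \text{otherwise,}
\end{cases} \qquad
y_i = \sum_{j=1}^{T} \operatorname{softmax}(\tilde{a}_i)_j \, v_j

The win_local and stride_glb op params are read but explicitly marked GGML_UNUSED, so this baseline applies only the global top-k mask; the local-window and global-stride patterns are presumably left for a later change.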
ggml/tests/test_sparsek_cpu.c

Lines changed: 0 additions & 40 deletions
This file was deleted.

tests/test_sparsek_cpu.c

Lines changed: 0 additions & 50 deletions
This file was deleted.

tmp-test/test_sparsek_cpu.c

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments