Commit 5d6d3b7

Add CPU support for SparseK Attention (without performance checks)

Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>

1 parent 66248d2

File tree: 5 files changed, +174 −22 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions

```diff
@@ -1952,6 +1952,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
+        case GGML_OP_SPARSEK_ATTN:
+            {
+                ggml_compute_forward_sparsek_attn(params, tensor);
+                break;
+            }
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
```

ggml/src/ggml-cpu/ops.cpp

Lines changed: 82 additions & 0 deletions

```diff
@@ -7907,6 +7907,88 @@ void ggml_compute_forward_argsort(
     }
 }
 
+//------------------------------------------------------------------------------
+// SparseK Attention (CPU)
+//------------------------------------------------------------------------------
+
+static void ggml_compute_forward_sparsek_attn_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    if (params->ith != 0) return; // main thread only
+
+    const struct ggml_tensor * Q = dst->src[0];
+    const struct ggml_tensor * K = dst->src[1];
+    const struct ggml_tensor * V = dst->src[2];
+
+    GGML_ASSERT(Q && K && V);
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F32);
+    GGML_ASSERT(V->type == GGML_TYPE_F32);
+
+    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
+    const int32_t win_local  = ggml_get_op_params_i32(dst, 1);
+    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
+
+    const int64_t D = Q->ne[0]; // embedding dim
+    const int64_t T = Q->ne[1]; // sequence length
+
+    const float * q = (const float *) Q->data;
+    const float * k = (const float *) K->data;
+    const float * v = (const float *) V->data;
+    float * out = (float *) dst->data;
+
+    // scaled dot-product scores: out[i*T + j] = q_i . k_j / sqrt(D)
+    for (int64_t i = 0; i < T; ++i) {
+        for (int64_t j = 0; j < T; ++j) {
+            float dot = 0.0f;
+            for (int64_t d = 0; d < D; ++d)
+                dot += q[i*D + d] * k[j*D + d];
+            out[i*T + j] = dot / sqrtf((float) D);
+        }
+    }
+
+    // sparsification: drop scores below the value at column k_top
+    for (int64_t i = 0; i < T; ++i) {
+        float * row = &out[i*T];
+        for (int64_t j = 0; j < T; ++j)
+            if (row[j] < row[k_top]) row[j] = -INFINITY;
+    }
+
+    // row-wise softmax
+    for (int64_t i = 0; i < T; ++i) {
+        float maxv = -INFINITY;
+        for (int64_t j = 0; j < T; ++j)
+            if (out[i*T + j] > maxv) maxv = out[i*T + j];
+        float sum = 0.0f;
+        for (int64_t j = 0; j < T; ++j) {
+            out[i*T + j] = expf(out[i*T + j] - maxv);
+            sum += out[i*T + j];
+        }
+        for (int64_t j = 0; j < T; ++j)
+            out[i*T + j] /= sum;
+    }
+
+    // weighted sum over V (note: result aliases out, both point at dst->data)
+    float * result = (float *) dst->data;
+    for (int64_t i = 0; i < T; ++i) {
+        for (int64_t d = 0; d < D; ++d) {
+            float sum = 0.0f;
+            for (int64_t j = 0; j < T; ++j)
+                sum += out[i*T + j] * v[j*D + d];
+            result[i*D + d] = sum;
+        }
+    }
+
+    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
+            k_top, win_local, stride_glb);
+}
+
+void ggml_compute_forward_sparsek_attn(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_sparsek_attn_f32(params, dst);
+}
+
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16(
```
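Two caveats in the committed kernel are worth noting (consistent with the "without performance checks" framing of the commit). It effectively computes dense softmax(Q·Kᵀ/√D)·V for a single flat T×D head; win_local and stride_glb are read but only printed. Beyond that, the masking step thresholds each row against row[k_top], the score at an arbitrary column of the unsorted row, so it is not yet a true top-k selection; and out and result both point at dst->data, so the final weighted sum overwrites attention weights it still needs to read. The following is a minimal sketch, not part of the commit, of how those two stages could look once fixed; sparsek_topk_mask, sparsek_weighted_sum, and the tmp/dst_data buffers are hypothetical names, and in a real ggml kernel the scratch space would presumably come from params->wdata:

```c
// Sketch only: a true top-k mask plus a non-aliased output, for one head.
// Assumes "scores" (T*T), "tmp" (T) and "dst_data" (T*D) are distinct buffers.
#include <math.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

static int cmp_desc_f32(const void * a, const void * b) {
    const float fa = *(const float *) a, fb = *(const float *) b;
    return (fa < fb) - (fa > fb); // descending order
}

// keep, per row, only the k_top largest scores; mask the rest to -inf
static void sparsek_topk_mask(float * scores, float * tmp, int64_t T, int32_t k_top) {
    for (int64_t i = 0; i < T; ++i) {
        float * row = &scores[i*T];
        // find the k_top-th largest value by sorting a copy of the row
        memcpy(tmp, row, (size_t) T * sizeof(float));
        qsort(tmp, (size_t) T, sizeof(float), cmp_desc_f32);
        const float thresh = tmp[k_top < T ? k_top - 1 : T - 1];
        for (int64_t j = 0; j < T; ++j)
            if (row[j] < thresh) row[j] = -INFINITY;
    }
}

// dst_data must not alias scores: row i's weights are read D times
static void sparsek_weighted_sum(const float * scores, const float * v,
                                 float * dst_data, int64_t T, int64_t D) {
    for (int64_t i = 0; i < T; ++i) {
        for (int64_t d = 0; d < D; ++d) {
            float sum = 0.0f;
            for (int64_t j = 0; j < T; ++j)
                sum += scores[i*T + j] * v[j*D + d];
            dst_data[i*D + d] = sum;
        }
    }
}
```

A qsort per row is O(T² log T) overall and is used here only for clarity; a partial selection (e.g. a heap of size k_top) would be the obvious refinement once the semantics are settled.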

ggml/src/ggml-cpu/ops.h

Lines changed: 2 additions & 0 deletions

```diff
@@ -86,6 +86,8 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
```

ggml/src/ggml.c

Lines changed: 26 additions & 20 deletions

```diff
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
-
+    "SPARSEK_ATTN",
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
-    "sparsek_attn(Q, K, V, k_top, win_local, stride_global)",
+    "sparsek_attn(x)",
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -5073,36 +5073,42 @@ struct ggml_tensor * ggml_sparsek_attn(
         int32_t win_local,
         int32_t stride_global) {
 
-    // suppress warnings (while the parameters are not yet used)
-    GGML_UNUSED(k_top);
-    GGML_UNUSED(win_local);
-    GGML_UNUSED(stride_global);
-
-    // basic sanity checks
-    GGML_ASSERT(Q != NULL);
-    GGML_ASSERT(K != NULL);
-    GGML_ASSERT(V != NULL);
     GGML_ASSERT(ggml_can_mul_mat(K, Q));
+    GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);
+
+    int64_t ne[4] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
 
-    // create an output tensor with the appropriate dimensions
-    int64_t ne[GGML_MAX_DIMS] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+    int32_t params_i32[3] = { k_top, win_local, stride_global };
+    ggml_set_op_params(result, params_i32, sizeof(params_i32));
 
-    // set the operator type and sources
     result->op = GGML_OP_SPARSEK_ATTN;
     result->src[0] = Q;
     result->src[1] = K;
     result->src[2] = V;
 
-    // store the numeric parameters in the op_params array (the ggml convention)
-    result->op_params[0] = k_top;
-    result->op_params[1] = win_local;
-    result->op_params[2] = stride_global;
-
     return result;
 }
 
 
+void ggml_sparsek_attn_set_params(struct ggml_tensor * a,
+                                  int32_t k_top,
+                                  int32_t win_local,
+                                  int32_t stride_global) {
+    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
+    ggml_set_op_params_i32(a, 0, k_top);
+    ggml_set_op_params_i32(a, 1, win_local);
+    ggml_set_op_params_i32(a, 2, stride_global);
+}
+
+int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
+    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
+    return ggml_get_op_params_i32(a, index);
+}
+
+
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
```
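The builder now stores the three integers through ggml_set_op_params instead of writing op_params directly, and the new setter/getter pair reads them back the same way the CPU kernel does. For context, a minimal sketch of how this API might be exercised from user code, assuming ggml.h from this branch; the tensor sizes and parameter values here are arbitrary:

```c
// Sketch: build a SPARSEK_ATTN node and read its parameters back.
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // Q/K/V laid out as (dim, tokens, heads, batch), matching the tests below
    const int64_t D = 128, T = 256, H = 8, B = 1;
    struct ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
    struct ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);
    struct ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, T, H, B);

    // k_top = 32, win_local = 64, stride_global = 128
    struct ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, 32, 64, 128);

    // parameters can also be adjusted after construction
    ggml_sparsek_attn_set_params(out, 16, 32, 64);
    const int32_t k_top = ggml_sparsek_attn_get_param(out, 0); // 16
    (void) k_top;

    ggml_free(ctx);
    return 0;
}
```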

tests/test-backend-ops.cpp

Lines changed: 59 additions & 2 deletions

```diff
@@ -1778,6 +1778,7 @@ struct test_example : public test_case {
 };
 
 
+
 // GGML_OP_UNARY
 struct test_unary : public test_case {
     const ggml_unary_op op;
@@ -5362,7 +5363,46 @@ struct test_leaky_relu : public test_case {
     }
 };
 
-// GGML_OP_FLASH_ATTN_EXT
+// GGML_OP_SPARSEK_ATTN
+struct test_sparsek_attn : public test_case {
+    const int64_t d_qk;
+    const int64_t d_v;
+    const int64_t n_head;
+    const int64_t n_tokens;
+    const int64_t batch;
+    const int32_t k_top;
+    const int32_t win_local;
+    const int32_t stride_global;
+
+    std::string vars() override {
+        return VARS_TO_STR9(d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global, 0);
+    }
+
+    test_sparsek_attn(int64_t d_qk = 128, int64_t d_v = 128, int64_t n_head = 8,
+                      int64_t n_tokens = 256, int64_t batch = 4,
+                      int32_t k_top = 32, int32_t win_local = 64, int32_t stride_global = 128)
+        : d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch),
+          k_top(k_top), win_local(win_local), stride_global(stride_global) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_q = n_tokens;
+        ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_q, n_head, batch);
+        ggml_set_name(Q, "Q");
+        ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch);
+        ggml_set_name(K, "K");
+        ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch);
+        ggml_set_name(V, "V");
+
+        ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global);
+        ggml_set_name(out, "SPARSEK_ATTN_out");
+
+        return out;
+    }
+};
+
+
+
+// GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size
     const int64_t hsv; // V head size
@@ -7095,7 +7135,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
                 if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
                 if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
-
+
                 for (bool mask : { true, false } ) {
                     for (bool sinks : { true, false } ) {
                         for (float max_bias : { 0.0f, 8.0f }) {
@@ -7134,6 +7174,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             }
         }
     }
+    // ---- SPARSEK_ATTN --------------------------------------------------
+    for (int64_t d_qk : {64, 128}) {
+        for (int64_t d_v : {64, 128}) {
+            for (int64_t n_head : {4, 8}) {
+                for (int64_t kv : {113, 512}) {
+                    for (int64_t b : {1, 4}) {
+                        for (int32_t k_top : {16, 32}) {
+                            for (int32_t win_local : {32, 64}) {
+                                test_cases.emplace_back(new test_sparsek_attn(
+                                    d_qk, d_v, n_head, kv, b, k_top, win_local, /*stride_global*/ 128));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
 
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
     test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
```
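These loops register 128 parameter combinations. If the harness's usual entry points apply, filtering by op name (e.g. `test-backend-ops test -o SPARSEK_ATTN`, assuming the standard `-o` filter) should run only the new cases; and since only the CPU backend implements the op so far, other backends would be expected to report it as unsupported rather than produce comparison results.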
