Commit 66248d2

Add skeleton for GGML_OP_SPARSEK_ATTN (SparseK Attention) to ggml.c/h: new operator definition and tensor creation; backend implementation pending
Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>
1 parent: 3479efd

2 files changed: +54 −4 lines

ggml/include/ggml.h (11 additions, 1 deletion)
@@ -529,7 +529,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
-
+        GGML_OP_SPARSEK_ATTN,
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -2231,6 +2231,16 @@ extern "C" {
     // n_head % ne32 == 0
     // ne3 % ne33 == 0
     //
+
+    GGML_API struct ggml_tensor * ggml_sparsek_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * Q,
+            struct ggml_tensor  * K,
+            struct ggml_tensor  * V,
+            int32_t               k_top,
+            int32_t               win_local,
+            int32_t               stride_global);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
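For orientation, here is a minimal usage sketch of the new API (not part of the commit). It assumes Q/K/V follow the same shape convention as ggml_flash_attn_ext ([d_head, n_tokens, n_head, n_batch]); the k_top, win_local, and stride_global values are arbitrary placeholders.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // example shapes: head size 64, 256 key/value tokens, 32 query tokens, 8 heads
    struct ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64,  32, 8, 1);
    struct ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 256, 8, 1);
    struct ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 256, 8, 1);

    // placeholder values for k_top, win_local, stride_global
    struct ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, 32, 64, 128);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // the graph is only built here; computing it is not expected to work
    // until the pending backend kernel for GGML_OP_SPARSEK_ATTN lands

    ggml_free(ctx);
    return 0;
}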

ggml/src/ggml.c (43 additions, 3 deletions)
@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
-
+    "sparsek_attn(Q, K, V, k_top, win_local, stride_global)",
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
+static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5063,6 +5063,46 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
+// ggml_sparsek_attn
+struct ggml_tensor * ggml_sparsek_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * Q,
+        struct ggml_tensor  * K,
+        struct ggml_tensor  * V,
+        int32_t               k_top,
+        int32_t               win_local,
+        int32_t               stride_global) {
+
+    // suppress unused-parameter warnings (the parameters are not used yet)
+    GGML_UNUSED(k_top);
+    GGML_UNUSED(win_local);
+    GGML_UNUSED(stride_global);
+
+    // basic sanity checks
+    GGML_ASSERT(Q != NULL);
+    GGML_ASSERT(K != NULL);
+    GGML_ASSERT(V != NULL);
+    GGML_ASSERT(ggml_can_mul_mat(K, Q));
+
+    // create the output tensor with the appropriate dimensions
+    int64_t ne[GGML_MAX_DIMS] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+
+    // set the operator type and its sources
+    result->op     = GGML_OP_SPARSEK_ATTN;
+    result->src[0] = Q;
+    result->src[1] = K;
+    result->src[2] = V;
+
+    // store the integer parameters in op_params (the convention used in ggml)
+    result->op_params[0] = k_top;
+    result->op_params[1] = win_local;
+    result->op_params[2] = stride_global;
+
+    return result;
+}
+
+
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
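The commit message notes that the backend implementation is still pending. As a hedged sketch only (the function name and its wiring into the compute dispatch are hypothetical, not part of this commit), a future kernel would read the parameters back from dst->op_params in the same order they were stored above:

// hypothetical sketch of the pending backend side
static void ggml_compute_forward_sparsek_attn(struct ggml_tensor * dst) {
    const struct ggml_tensor * Q = dst->src[0];
    const struct ggml_tensor * K = dst->src[1];
    const struct ggml_tensor * V = dst->src[2];

    // recover the parameters stored by ggml_sparsek_attn()
    const int32_t k_top         = dst->op_params[0];
    const int32_t win_local     = dst->op_params[1];
    const int32_t stride_global = dst->op_params[2];

    // TODO: actual SparseK attention math goes here
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V);
    GGML_UNUSED(k_top); GGML_UNUSED(win_local); GGML_UNUSED(stride_global);
}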
