@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
10191019 "GLU" ,
10201020};
10211021
1022- static_assert (GGML_OP_COUNT == 90 , "GGML_OP_COUNT != 90 " );
1022+ static_assert (GGML_OP_COUNT == 91 , "GGML_OP_COUNT != 91 " );
10231023
10241024static const char * GGML_OP_SYMBOL [GGML_OP_COUNT ] = {
10251025 "none" ,
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10941094 "timestep_embedding(timesteps, dim, max_period)" ,
10951095 "argsort(x)" ,
10961096 "leaky_relu(x)" ,
1097-
1097+ "sparsek_attn(Q, K, V, k_top, win_local, stride_global)" ,
10981098 "flash_attn_ext(x)" ,
10991099 "flash_attn_back(x)" ,
11001100 "ssm_conv(x)" ,
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
11231123 "glu(x)" ,
11241124};
11251125
1126- static_assert (GGML_OP_COUNT == 90 , "GGML_OP_COUNT != 90 " );
1126+ static_assert (GGML_OP_COUNT == 91 , "GGML_OP_COUNT != 91 " );
11271127
11281128static_assert (GGML_OP_POOL_COUNT == 2 , "GGML_OP_POOL_COUNT != 2" );
11291129
@@ -5063,6 +5063,46 @@ struct ggml_tensor * ggml_top_k(
50635063 return result ;
50645064}
50655065
5066+ // ggml_sparsek_attn
5067+ struct ggml_tensor * ggml_sparsek_attn (
5068+ struct ggml_context * ctx ,
5069+ struct ggml_tensor * Q ,
5070+ struct ggml_tensor * K ,
5071+ struct ggml_tensor * V ,
5072+ int32_t k_top ,
5073+ int32_t win_local ,
5074+ int32_t stride_global ) {
5075+
5076+ // ביטול אזהרות (אם טרם משתמשים בפרמטרים)
5077+ GGML_UNUSED (k_top );
5078+ GGML_UNUSED (win_local );
5079+ GGML_UNUSED (stride_global );
5080+
5081+ // בדיקות תקינות בסיסיות
5082+ GGML_ASSERT (Q != NULL );
5083+ GGML_ASSERT (K != NULL );
5084+ GGML_ASSERT (V != NULL );
5085+ GGML_ASSERT (ggml_can_mul_mat (K , Q ));
5086+
5087+ // יצירת טנזור פלט בממדים המתאימים
5088+ int64_t ne [GGML_MAX_DIMS ] = { V -> ne [0 ], Q -> ne [2 ], Q -> ne [1 ], Q -> ne [3 ] };
5089+ struct ggml_tensor * result = ggml_new_tensor (ctx , GGML_TYPE_F32 , GGML_MAX_DIMS , ne );
5090+
5091+ // הגדרת סוג האופרטור והמקורות
5092+ result -> op = GGML_OP_SPARSEK_ATTN ;
5093+ result -> src [0 ] = Q ;
5094+ result -> src [1 ] = K ;
5095+ result -> src [2 ] = V ;
5096+
5097+ // שמירת הפרמטרים המספריים במערך op_params (שיטה הנהוגה ב־ggml)
5098+ result -> op_params [0 ] = k_top ;
5099+ result -> op_params [1 ] = win_local ;
5100+ result -> op_params [2 ] = stride_global ;
5101+
5102+ return result ;
5103+ }
5104+
5105+
50665106// ggml_flash_attn_ext
50675107
50685108struct ggml_tensor * ggml_flash_attn_ext (
0 commit comments