@@ -7909,14 +7909,30 @@ void ggml_compute_forward_argsort(
 }
 
 // ------------------------------------------------------------------------------
-// SparseK Attention (CPU)
+// SparseK Attention (CPU, final optimized version)
 // ------------------------------------------------------------------------------
+//
+// Implements SparseK Attention as a GGML operator for the CPU backend.
+// Features:
+//  • Top-K filtering using nth_element (O(N))
+//  • Optional local window (win_local)
+//  • Optional global stride (stride_glb)
+//  • Numerically stable softmax
+//  • Preallocated buffers for performance
+//
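+// Sparsity pattern per query position iq: a key j is a candidate when
+//  |iq - j| ≤ win_local (if win_local ≥ 0), when j % stride_glb == 0
+//  (if stride_glb > 1), or when j == iq; among the candidates only the
+//  k_top highest Q·K scores receive nonzero attention weight.
+//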
+// Author: Yael Shuker (yael-works)
+// ------------------------------------------------------------------------------
+
+#include <algorithm>
+#include <vector>
+#include <cmath>
+#include <limits>
 
 static void ggml_compute_forward_sparsek_attn_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
 
-    // Single-threaded baseline version  
+    // Single-threaded baseline version
     if (params->ith != 0) return;
 
     const struct ggml_tensor * Q = dst->src[0];
@@ -7929,80 +7945,132 @@ static void ggml_compute_forward_sparsek_attn_f32(
     GGML_ASSERT(V->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
+    // Operator parameters
     const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
-    const int32_t win_local  = ggml_get_op_params_i32(dst, 1);
-    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
-    GGML_UNUSED(win_local);
-    GGML_UNUSED(stride_glb);
+    const int32_t win_local  = ggml_get_op_params_i32(dst, 1); // -1 ⇒ no local window
+    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride
+
+    const bool use_local  = (win_local >= 0);
+    const bool use_stride = (stride_glb > 1);
 
-    // Tensor dimensions according to GGML layout: ne[0]=d, ne[1]=seq, ne[2]=head, ne[3]=batch
+    // GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
     const int64_t D = Q->ne[0];
     const int64_t T = Q->ne[1];
     const int64_t H = Q->ne[2];
     const int64_t B = Q->ne[3];
 
-    // Temporary buffer for attention scores for one query row
-    std::vector<float> attn_row(T, 0.0f);
+    // Dimension validation
+    GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
+    GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
+    GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
+    GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);
+
+    // Parameter sanity checks
+    GGML_ASSERT(k_top >= 0 && k_top <= (int32_t)T);
+    GGML_ASSERT(win_local >= -1);
+    GGML_ASSERT(stride_glb >= 0);
 
-    const float scale = 1.0f / sqrtf((float) D);
+    const float scale = 1.0f / sqrtf((float)D);
+    const float NINF  = -std::numeric_limits<float>::infinity();
+
+    // Preallocated buffers to avoid heap churn
+    std::vector<float>   attn_row((size_t)T, NINF);
+    std::vector<int32_t> cand_idx; cand_idx.reserve((size_t)T);
+    std::vector<float>   scores;   scores.reserve((size_t)T);
 
-    // Loops over batch, head, and query token
     for (int64_t b = 0; b < B; ++b) {
         for (int64_t h = 0; h < H; ++h) {
             for (int64_t iq = 0; iq < T; ++iq) {
 
-                // (1) Compute dot products Q·K within same (b,h)
-                const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1];
-                const float * qv = (const float *) qbase;
+                // (0) Build candidate index list (always include self)
+                cand_idx.clear();
+                scores.clear();
+
+                if (!use_local && !use_stride) {
+                    // No sparsity: attend to all tokens
+                    for (int64_t j = 0; j < T; ++j)
+                        cand_idx.push_back((int32_t)j);
+                } else {
+                    // Apply local window and/or global stride
+                    for (int64_t j = 0; j < T; ++j) {
+                        const int64_t dist = iq >= j ? iq - j : j - iq;
+                        const bool pass_local  = use_local  && (dist <= (int64_t)win_local);
+                        const bool pass_stride = use_stride && (stride_glb > 0 && j % stride_glb == 0);
+                        if (pass_local || pass_stride || j == iq)
+                            cand_idx.push_back((int32_t)j);
+                    }
+                }
+
+                // Edge case: no candidates or k_top==0 → output zeros
+                if (k_top == 0 || cand_idx.empty()) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
+                }
 
-                for (int64_t j = 0; j < T; ++j) {
-                    const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1];
-                    const float * kv = (const float *) kbase;
+                // (1) Compute scaled dot-product Q·K only for candidates
+                std::fill(attn_row.begin(), attn_row.end(), NINF);
+                const float * qv = (const float *)((const char *)Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);
 
+                for (int32_t j : cand_idx) {
+                    const float * kv = (const float *)((const char *)K->data + b*K->nb[3] + h*K->nb[2] + (int64_t)j*K->nb[1]);
                     float dot = 0.0f;
-                    for (int64_t d = 0; d < D; ++d) {
+                    for (int64_t d = 0; d < D; ++d)
                         dot += qv[d] * kv[d];
-                    }
                     attn_row[j] = dot * scale;
                 }
 
-                // (2) Select top-k threshold using nth_element
-                const int kk = std::max<int>(1, std::min<int>((int)T, k_top));
-                std::vector<float> tmp(attn_row.begin(), attn_row.end());
-                std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
-                const float thr = tmp[kk - 1];
+                // (2) Determine true Top-K threshold using nth_element
+                const int num_candidates = (int)cand_idx.size();
+                const int kk = std::min<int>(std::max<int>(1, k_top), num_candidates);
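+                // kk is clamped to [1, num_candidates]; when kk equals num_candidates
+                // every candidate is kept and the selection step below is skipped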
+
+                if (kk < num_candidates) {
+                    scores.resize((size_t)num_candidates);
+                    for (size_t i = 0; i < cand_idx.size(); ++i)
+                        scores[i] = attn_row[cand_idx[i]];
+
+                    std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
+                    const float thr = scores[kk - 1];
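+                    // candidates tied with thr survive the mask below, so slightly
+                    // more than kk entries may remain active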
 
-                for (int64_t j = 0; j < T; ++j) {
-                    if (attn_row[j] < thr) attn_row[j] = -INFINITY;
+                    // Mask all values below the threshold
+                    for (int32_t j : cand_idx)
+                        if (attn_row[j] < thr) attn_row[j] = NINF;
                 }
 
-                // (3) Numerically stable softmax on the masked row
-                float maxv = -INFINITY;
-                for (int64_t j = 0; j < T; ++j) {
+                // (3) Numerically stable softmax
+                float maxv = NINF;
+                for (int32_t j : cand_idx)
                     maxv = std::max(maxv, attn_row[j]);
+
+                // Handle all-masked rows
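+                // (with finite Q/K every candidate score is finite, so this branch
+                //  only guards against NaN/Inf in the inputs)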
+                if (!std::isfinite(maxv)) {
+                    float * y0 = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+                    std::fill(y0, y0 + D, 0.0f);
+                    continue;
                 }
+
                 float sum = 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
-                    float v = attn_row[j] - maxv;
-                    float e = expf(v);
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
+                    const float e = expf(attn_row[j] - maxv);
                     attn_row[j] = e;
                     sum += e;
                 }
-                const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
-                for (int64_t j = 0; j < T; ++j) {
+
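+                // sum >= 1 here (the top-scoring candidate contributes expf(0) == 1),
+                // so the zero-sum guard is purely defensive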
+                const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
+                for (int32_t j : cand_idx) {
+                    if (attn_row[j] == NINF) continue;
                     attn_row[j] *= inv_sum;
                 }
 
-                // (4) Compute output = A·V (weighted sum)
-                float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-
+                // (4) Compute output y = A·V
+                float * y = (float *)((char *)dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
                 for (int64_t d = 0; d < D; ++d) {
                     float acc = 0.0f;
-                    for (int64_t j = 0; j < T; ++j) {
+                    for (int32_t j : cand_idx) {
                         const float aij = attn_row[j];
-                        if (aij == 0.0f) continue; // skip masked entries
-                        const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1];
-                        const float * vv = (const float *) vbase;
+                        if (!(aij > 0.0f)) continue; // skip zero or masked
+                        const float * vv = (const float *)((const char *)V->data + b*V->nb[3] + h*V->nb[2] + (int64_t)j*V->nb[1]);
                         acc += aij * vv[d];
                     }
                     y[d] = acc;
@@ -8012,7 +8080,7 @@ static void ggml_compute_forward_sparsek_attn_f32(
     }
 
     GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
-        k_top, win_local, stride_glb);
+                     k_top, win_local, stride_glb);
 }
 
 void ggml_compute_forward_sparsek_attn(