@@ -1778,6 +1778,7 @@ struct test_example : public test_case {
17781778};
17791779
17801780
1781+ 
17811782//  GGML_OP_UNARY
17821783struct  test_unary  : public  test_case  {
17831784    const  ggml_unary_op op;
@@ -5362,7 +5363,46 @@ struct test_leaky_relu : public test_case {
53625363    }
53635364};
53645365
5365- //  GGML_OP_FLASH_ATTN_EXT
5366+ //  GGML_OP_SPARSEK_ATTN
5367+ struct  test_sparsek_attn  : public  test_case  {
5368+     const  int64_t  d_qk;
5369+     const  int64_t  d_v;
5370+     const  int64_t  n_head;
5371+     const  int64_t  n_tokens;
5372+     const  int64_t  batch;
5373+     const  int32_t  k_top;
5374+     const  int32_t  win_local;
5375+     const  int32_t  stride_global;
5376+ 
5377+     std::string vars () override  {
5378+         return  VARS_TO_STR9 (d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global, 0 );
5379+     }
5380+ 
5381+     test_sparsek_attn (int64_t  d_qk = 128 , int64_t  d_v = 128 , int64_t  n_head = 8 ,
5382+                       int64_t  n_tokens = 256 , int64_t  batch = 4 ,
5383+                       int32_t  k_top = 32 , int32_t  win_local = 64 , int32_t  stride_global = 128 )
5384+         : d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch),
5385+           k_top (k_top), win_local(win_local), stride_global(stride_global) {}
5386+ 
5387+     ggml_tensor * build_graph (ggml_context * ctx) override  {
5388+         const  int64_t  n_q = n_tokens;
5389+         ggml_tensor * Q = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, d_qk, n_q, n_head, batch);
5390+         ggml_set_name (Q, " Q" 
5391+         ggml_tensor * K = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch);
5392+         ggml_set_name (K, " K" 
5393+         ggml_tensor * V = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch);
5394+         ggml_set_name (V, " V" 
5395+ 
5396+         ggml_tensor * out = ggml_sparsek_attn (ctx, Q, K, V, k_top, win_local, stride_global);
5397+         ggml_set_name (out, " SPARSEK_ATTN_out" 
5398+ 
5399+         return  out;
5400+     }
5401+ };
5402+ 
5403+ 
5404+ 
5405+ //  GGML_OP_FLAsH_ATTN_EXT
53665406struct  test_flash_attn_ext  : public  test_case  {
53675407    const  int64_t  hsk; //  K head size
53685408    const  int64_t  hsv; //  V head size
@@ -7095,7 +7135,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
70957135            if  (hsk != 192  && hsk != 576  && hsk != hsv) continue ;
70967136            if  (hsk == 192  && (hsv != 128  && hsv != 192 )) continue ;
70977137            if  (hsk == 576  && hsv != 512 ) continue ; //  DeepSeek MLA
7098- 
7138+            
70997139            for  (bool  mask : { true , false  } ) {
71007140                for  (bool  sinks : { true , false  } ) {
71017141                    for  (float  max_bias : { 0 .0f , 8 .0f  }) {
@@ -7134,6 +7174,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
71347174            }
71357175        }
71367176    }
7177+     //  ---- SPARSEK_ATTN --------------------------------------------------
7178+     for  (int64_t  d_qk : {64 , 128 }) {
7179+         for  (int64_t  d_v : {64 , 128 }) {
7180+             for  (int64_t  n_head : {4 , 8 }) {
7181+                 for  (int64_t  kv : {113 , 512 }) {  
7182+                     for  (int64_t  b : {1 , 4 }) {
7183+                         for  (int32_t  k_top : {16 , 32 }) {
7184+                             for  (int32_t  win_local : {32 , 64 }) {
7185+                                 test_cases.emplace_back (new  test_sparsek_attn (
7186+                                     d_qk, d_v, n_head, kv, b, k_top, win_local, /* stride_global*/ 128 ));
7187+                             }
7188+                         }
7189+                     }
7190+                 }
7191+             }
7192+         }
7193+     }
71377194
71387195    test_cases.emplace_back (new  test_cross_entropy_loss      (GGML_TYPE_F32, {   10 , 5 , 4 , 3 }));
71397196    test_cases.emplace_back (new  test_cross_entropy_loss      (GGML_TYPE_F32, {30000 , 1 , 1 , 1 }));
0 commit comments