@@ -261,17 +261,12 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = "unknown";
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
-        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
-        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
-        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
-    };
-
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE"      :
+                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD"  :
+                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED"   :
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
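Aside (not part of the diff): the enum-to-string mapping shown in this hunk could equally be written as a small helper that returns directly from a switch. This is only an illustrative sketch; the name swa_type_name is hypothetical and does not exist in llama.cpp:

    // hypothetical helper, equivalent to the ternary chain / switch above
    static const char * swa_type_name(llama_swa_type swa_type) {
        switch (swa_type) {
            case LLAMA_SWA_TYPE_NONE:      return "LLAMA_SWA_TYPE_NONE";
            case LLAMA_SWA_TYPE_STANDARD:  return "LLAMA_SWA_TYPE_STANDARD";
            case LLAMA_SWA_TYPE_CHUNKED:   return "LLAMA_SWA_TYPE_CHUNKED";
            case LLAMA_SWA_TYPE_SYMMETRIC: return "LLAMA_SWA_TYPE_SYMMETRIC";
        }
        return "unknown";
    }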
@@ -300,67 +295,50 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
-        for (int h = 0; h < 1; ++h) {
-            for (int i1 = 0; i1 < n_tokens; ++i1) {
-                const llama_seq_id s1 = ubatch->seq_id[i1][0];
-                const llama_pos    p1 = ubatch->pos[i1];
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
 
-                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
 
-                for (int i0 = 0; i0 < n_tokens; ++i0) {
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
-                    const llama_pos p0    = ubatch->pos[i0];
 
-                    // mask different sequences
                     if (s0 != s1) {
-                        continue;
+                        continue; // skip different sequences
                     }
 
-                    // mask future tokens
-                    if (cparams.causal_attn && p0 > p1) {
-                        continue;
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
                     }
 
-                    // apply SWA if any
-                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
-                        continue;
-                    }
+                    // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                    //    continue; // skip masked tokens for SWA
+                    //}
 
-                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
+                    }
                 }
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
-    };
-
-    {
-        GGML_ASSERT(self_kq_mask);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-        float * data = (float *) self_kq_mask->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
-
-        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
-        }
     }
-
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        GGML_ASSERT(self_kq_mask_swa);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-
-        float * data = (float *) self_kq_mask_swa->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
-
-        fill_mask(data, hparams.n_swa, hparams.swa_type);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
-        }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
     }
 }
 
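For orientation (not part of the diff): the flat loop above fills the [n_kv, n_tokens] mask row by row. A key/value token i0 stays at -INFINITY unless it belongs to the same sequence as query token i1 and, under causal attention, is not in the future; allowed pairs get 0.0f, or -|pos[i0] - pos[i1]| when ALiBi is enabled. Below is a self-contained sketch of the same rule on a toy batch; it is standalone illustration code, not llama.cpp code, and names such as toy_pos/toy_seq are made up:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // standalone illustration of the masking rule used in set_input (not llama.cpp code)
    int main() {
        const int n_tokens = 4;
        const std::vector<int> toy_pos = {0, 1, 2, 0}; // position per token (made up)
        const std::vector<int> toy_seq = {0, 0, 0, 1}; // sequence id per token (made up)
        const bool causal_attn = true;

        std::vector<float> mask(n_tokens * n_tokens, -INFINITY);

        for (int i1 = 0; i1 < n_tokens; ++i1) {     // rows    = query tokens
            for (int i0 = 0; i0 < n_tokens; ++i0) { // columns = key/value tokens
                if (toy_seq[i0] != toy_seq[i1]) {
                    continue; // different sequences never attend to each other
                }
                if (causal_attn && toy_pos[i0] > toy_pos[i1]) {
                    continue; // future tokens stay masked
                }
                mask[i1*n_tokens + i0] = 0.0f; // would be -|p0 - p1| with ALiBi
            }
        }

        // print the mask in the same convention as print_mask: '0' = can attend, '∞' = masked
        for (int i1 = 0; i1 < n_tokens; ++i1) {
            for (int i0 = 0; i0 < n_tokens; ++i0) {
                std::printf("%s", std::isinf(mask[i1*n_tokens + i0]) ? " ∞" : " 0");
            }
            std::printf("\n");
        }
        return 0;
    }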
@@ -1357,10 +1335,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
+    const auto n_kv = k->ne[1];
+
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && kq_b == nullptr) {
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
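Note (not part of the diff): the added condition only takes the flash-attention branch when the KV length is a multiple of 256, matching the hardcoded padding mentioned in the TODO, and still requires the absence of a KQ bias. A minimal sketch of that gate; the 256 literal is the value visible in the hunk, and use_flash_attn_path is a hypothetical name:

    #include <cstdint>

    // hypothetical predicate mirroring the condition in build_attn_mha
    bool use_flash_attn_path(bool flash_attn, int64_t n_kv, bool has_kq_b) {
        // FA is used only without a KQ bias and when n_kv is padded to a multiple of 256
        return flash_attn && (n_kv % 256 == 0) && !has_kq_b;
    }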
@@ -1475,20 +1455,10 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->kq_mask);
 
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    } else {
-        inp->self_kq_mask_swa     = nullptr;
-        inp->self_kq_mask_swa_cnv = nullptr;
-    }
+    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
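For orientation (not part of the diff): the no-cache input now allocates a single F32 mask of shape [n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1] and casts it to F16 only when flash attention is enabled; the separate SWA mask is gone. A small sketch of the GGML_PAD-style rounding used for the second dimension; the padding value 64 is a placeholder, since GGML_KQ_MASK_PAD's actual value is not shown in this hunk:

    #include <cstdint>
    #include <cstdio>

    // smallest multiple of n that is >= x (same effect as GGML_PAD for power-of-two n)
    int64_t pad_to(int64_t x, int64_t n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const int64_t n_tokens    = 100;
        const int64_t kq_mask_pad = 64; // placeholder for GGML_KQ_MASK_PAD
        // mask shape allocated above: [n_tokens, pad_to(n_tokens, kq_mask_pad), 1, 1]
        std::printf("mask rows padded from %lld to %lld\n",
                    (long long) n_tokens, (long long) pad_to(n_tokens, kq_mask_pad));
        return 0;
    }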
@@ -1513,9 +1483,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams