@@ -261,17 +261,12 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = "unknown";
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
-        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
-        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
-        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
-    };
-
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE"      :
+                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD"  :
+                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED"   :
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int) n_swa, (int) n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -300,67 +295,50 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
-        for (int h = 0; h < 1; ++h) {
-            for (int i1 = 0; i1 < n_tokens; ++i1) {
-                const llama_seq_id s1 = ubatch->seq_id[i1][0];
-                const llama_pos    p1 = ubatch->pos[i1];
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
 
-                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
 
-                for (int i0 = 0; i0 < n_tokens; ++i0) {
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
-                    const llama_pos    p0 = ubatch->pos[i0];
 
-                    // mask different sequences
                     if (s0 != s1) {
-                        continue;
+                        continue; // skip different sequences
                     }
 
-                    // mask future tokens
-                    if (cparams.causal_attn && p0 > p1) {
-                        continue;
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
                     }
 
-                    // apply SWA if any
-                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
-                        continue;
-                    }
+                    // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                    //    continue; // skip masked tokens for SWA
+                    //}
 
-                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
+                    }
                 }
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
-    };
-
-    {
-        GGML_ASSERT(self_kq_mask);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-        float * data = (float *) self_kq_mask->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
-
-        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
-        }
     }
-
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        GGML_ASSERT(self_kq_mask_swa);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-
-        float * data = (float *) self_kq_mask_swa->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
-
-        fill_mask(data, hparams.n_swa, hparams.swa_type);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
-        }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
     }
 }
 
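For reference, here is a minimal standalone sketch (not part of the commit) of the mask layout the restored loop produces: one row of n_kv entries per query token i1, written at data[h*(n_kv*n_tokens) + i1*n_kv + i0], with 0.0f where attention is allowed and -INFINITY where it is masked. The toy batch below assumes a single sequence, positions 0..n-1, causal attention, and no ALiBi.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4;
    const int n_kv     = n_tokens; // no KV cache: one KV entry per batch token

    // start fully masked, then open up the allowed positions
    std::vector<float> data(n_kv * n_tokens, -INFINITY);

    for (int i1 = 0; i1 < n_tokens; ++i1) {     // query rows
        for (int i0 = 0; i0 < n_kv; ++i0) {     // key/value columns
            if (i0 <= i1) {                     // causal: cannot attend to future tokens
                data[i1*n_kv + i0] = 0.0f;      // 0 = can attend
            }
        }
    }

    // print in the same orientation as print_mask: rows = query tokens, columns = key/value tokens
    for (int i1 = 0; i1 < n_tokens; ++i1) {
        for (int i0 = 0; i0 < n_kv; ++i0) {
            printf("%s ", data[i1*n_kv + i0] == 0.0f ? "0" : "∞");
        }
        printf("\n");
    }
    return 0;
}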
@@ -1321,10 +1299,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
+    const auto n_kv = k->ne[1];
+
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && kq_b == nullptr) {
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
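For context on the added (n_kv % 256 == 0) check in this hunk: the flash-attention branch is now only taken when the KV length is already a multiple of the hardcoded 256 padding (see the TODO about ggml-provided padding above). Below is a tiny sketch of just that branch condition; fa_path_taken is a hypothetical helper written for illustration, not llama.cpp API.

#include <cstdint>
#include <cstdio>

// mirrors the guard in build_attn_mha: flash attention enabled,
// KV length already a multiple of the hardcoded 256 padding, and no KQ bias tensor
static bool fa_path_taken(int64_t n_kv, bool flash_attn, bool has_kq_b) {
    return flash_attn && (n_kv % 256 == 0) && !has_kq_b;
}

int main() {
    printf("%d\n", fa_path_taken(4096, true, false)); // 1: already padded, FA branch taken
    printf("%d\n", fa_path_taken(4100, true, false)); // 0: falls back to the non-FA path
    return 0;
}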
@@ -1439,20 +1419,10 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->kq_mask);
 
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    } else {
-        inp->self_kq_mask_swa     = nullptr;
-        inp->self_kq_mask_swa_cnv = nullptr;
-    }
+    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@@ -1477,9 +1447,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams