@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int) n_swa, (int) n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
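
The debug printout above renders the mask row by row, with '0' for key/value positions a query token may attend to and '∞' for masked ones. As a rough illustration of what that output encodes (a standalone toy, not a call into the static helper above, and assuming causal attention plus a standard sliding window of width n_swa):

```cpp
// Illustrative sketch only: reproduces the '0' / '∞' rendering for a toy batch,
// assuming causal attention plus a standard sliding window of width n_swa.
#include <cstdio>

int main() {
    const int n_tokens = 6; // rows = query tokens, columns = key/value tokens
    const int n_swa    = 3; // assumed window width

    for (int i1 = 0; i1 < n_tokens; ++i1) {
        for (int i0 = 0; i0 < n_tokens; ++i0) {
            const bool causal_ok = i0 <= i1;            // cannot attend to future tokens
            const bool in_window = (i1 - i0) < n_swa;   // assumption: standard SWA keeps the last n_swa positions
            std::printf(" %s", (causal_ok && in_window) ? "0" : "∞");
        }
        std::printf("\n");
    }
    return 0;
}
```
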
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos    p0 = ubatch->pos[i0];
 
+                    // mask different sequences
                     if (s0 != s1) {
-                        continue; // skip different sequences
+                        continue;
                     }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
 
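The key change in this hunk is that the mask-filling loop no longer hardcodes the non-SWA case: the per-position decision is delegated to llama_hparams::is_masked_swa(), and fill_mask() is invoked twice, once with (0, LLAMA_SWA_TYPE_NONE) for the regular mask and once with (hparams.n_swa, hparams.swa_type) for the SWA mask. For intuition only, a simplified sketch of what such a predicate can look like is below; the authoritative logic is llama_hparams::is_masked_swa() in llama.cpp, and the chunked/symmetric branches here are illustrative assumptions, not a copy of it.

```cpp
// Simplified sketch of a SWA masking predicate (returns true when the (p0, p1) pair is masked).
// Assumes llama.h is included for llama_pos / llama_swa_type; not the upstream implementation.
#include <cstdlib> // std::abs

static bool is_masked_swa_sketch(int n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            return false;                            // no window: nothing extra is masked
        case LLAMA_SWA_TYPE_STANDARD:
            return p1 - p0 >= n_swa;                 // key is too far behind the query
        case LLAMA_SWA_TYPE_CHUNKED:
            return p0 < (p1 / n_swa) * n_swa;        // key lies before the query's chunk start
        case LLAMA_SWA_TYPE_SYMMETRIC:
            return std::abs(p1 - p0) > n_swa / 2;    // key outside a window centered on the query
    }
    return false;
}
```

Note also that both mask buffers are pre-filled with -INFINITY via std::fill, so fill_mask() only writes the positions that are allowed to attend.
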
@@ -1299,12 +1321,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1439,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
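
build_attn_inp_no_cache() now allocates a second mask tensor when the model uses a sliding window, mirroring the self_kq_mask / self_kq_mask_swa naming plus their F16-cast _cnv variants for flash attention. The matching header change is not part of this diff; purely as a guess inferred from the member and getter names used in these hunks, the input struct presumably gains something along these lines:

```cpp
// Hypothetical sketch of the header side (llama-graph.h is not shown in this diff);
// member and getter names are inferred from their uses above, comments are assumptions.
struct llm_graph_input_attn_no_cache : public llm_graph_input_i {
    // ...

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv;     }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

    ggml_tensor * self_kq_mask         = nullptr; // F32 mask over the batch tokens
    ggml_tensor * self_kq_mask_cnv     = nullptr; // F16 cast when flash_attn, else the same tensor
    ggml_tensor * self_kq_mask_swa     = nullptr; // only allocated when swa_type != LLAMA_SWA_TYPE_NONE
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // F16 cast when flash_attn, else the same tensor
};
```

The else branch that nulls both SWA members keeps the SWA mask from being dereferenced for non-SWA models, and build_attn() below selects the right mask per layer via hparams.is_swa(il).
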
@@ -1447,7 +1477,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams