@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
               bool   v_trans,
               bool   offload,
           uint32_t   kv_size,
-          uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+          uint32_t   padding,
+          uint32_t   n_swa,
+    llama_swa_type   swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

     this->type_k = type_k;
@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }

-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -670,38 +672,23 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];

                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id)           // not the correct sequence
-                        || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
+                    bool masked = false;
+
+                    // mask the token if not the same sequence
+                    masked |= !cells[i].has_seq_id(seq_id);
+
+                    // mask future tokens
+                    masked |= causal_attn && cells[i].pos > pos;
+
+                    // apply SWA if any
+                    masked |= is_masked_swa(cells[i].pos, pos);

-                    if (swa) {
-                        // may need to cut off old tokens for sliding window
-                        switch (hparams.swa_type) {
-                            case LLAMA_SWA_TYPE_STANDARD:
-                                {
-                                    if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                            case LLAMA_SWA_TYPE_CHUNKED:
-                                {
-                                    const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;
-
-                                    if (cells[i].pos < pos_chunk_start) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                        }
+                    float f = 0.0f;
+
+                    if (masked) {
+                        f = -INFINITY;
+                    } else if (hparams.use_alibi) {
+                        f = -std::abs(cells[i].pos - pos);
                     }

                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
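For reference, the refactored per-cell logic reduces to the standalone sketch below: the three mask conditions are OR-ed together first, and only unmasked cells receive either 0.0f or the ALiBi distance bias. The function name and parameters are illustrative stand-ins for the class members used above, not part of the llama.cpp API.

```cpp
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <limits>

// Hypothetical free-function version of the per-cell mask value computed in
// the refactored loop (a sketch, not the library implementation).
static float kq_mask_value(bool same_seq, bool causal_attn, bool use_alibi,
                           int cell_pos, int pos, bool masked_swa) {
    bool masked = false;
    masked |= !same_seq;                      // not the same sequence
    masked |= causal_attn && cell_pos > pos;  // future token under causal attention
    masked |= masked_swa;                     // outside the sliding window (if any)

    if (masked) {
        return -std::numeric_limits<float>::infinity();
    }
    return use_alibi ? -std::abs(cell_pos - pos) : 0.0f;
}

int main() {
    // cache cell at position 9, query at position 7 -> masked (future token)
    std::printf("%f\n", kq_mask_value(true, true, false, 9, 7, false));
    // cell at position 5, query at position 7, ALiBi enabled -> bias of -2
    std::printf("%f\n", kq_mask_value(true, true, true, 5, 7, false));
    return 0;
}
```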
@@ -1191,6 +1178,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }

+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
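To see which cells the new `is_masked_swa` predicate keeps, the sketch below reimplements it outside the class with stand-in names (`swa_type_t`, `SWA_STANDARD`, `SWA_CHUNKED` mirror the `LLAMA_SWA_TYPE_*` values above; none of this is llama.cpp API). With `n_swa = 4` and a query at position `p1 = 10`, the standard window keeps cells at positions 7..10, while the chunked variant keeps only the current chunk, positions 8..10.

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for llama_swa_type; llama_pos is int32_t in llama.cpp.
enum swa_type_t { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

// Minimal re-implementation of the predicate added above, for experimentation.
static bool is_masked_swa(swa_type_t swa_type, uint32_t n_swa, int32_t p0, int32_t p1) {
    switch (swa_type) {
        case SWA_NONE:
            break;
        case SWA_STANDARD:
            // mask cells that fall out of the trailing window of n_swa positions
            if (p1 - p0 >= (int32_t) n_swa) {
                return true;
            }
            break;
        case SWA_CHUNKED:
            // mask cells that precede the start of the current chunk
            if (p0 < (p1 / (int32_t) n_swa) * (int32_t) n_swa) {
                return true;
            }
            break;
    }
    return false;
}

int main() {
    const uint32_t n_swa = 4;
    const int32_t  p1    = 10; // query position

    // STANDARD keeps p0 in [7, 10]; CHUNKED keeps p0 in [8, 10] (chunk start = 8)
    for (int32_t p0 = 0; p0 <= p1; ++p0) {
        std::printf("p0=%2d standard=%d chunked=%d\n", (int) p0,
                (int) !is_masked_swa(SWA_STANDARD, n_swa, p0, p1),
                (int) !is_masked_swa(SWA_CHUNKED,  n_swa, p0, p1));
    }
    return 0;
}
```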
@@ -1586,11 +1597,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);

-    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, kv_size_base, padding,
+            0, LLAMA_SWA_TYPE_NONE);

     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);

-    kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, kv_size_swa, padding,
+            hparams.n_swa, hparams.swa_type);
 }

 void llama_kv_cache_unified_iswa::clear() {
@@ -2801,5 +2818,4 @@ void llama_kv_cache_view_free(llama_kv_cache_view * view) {
 void llama_kv_cache_view_update(llama_kv_cache_view *, const llama_kv_cache *) {
     // TODO: will be removed soon, keep this for now to avoid too many changes in
     //       https://github.com/ggml-org/llama.cpp/pull/13194
-    GGML_ABORT("not implemented");
 }