@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                   bool   v_trans,
                   bool   offload,
               uint32_t   kv_size,
-              uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+              uint32_t   padding,
+              uint32_t   n_swa,
+        llama_swa_type   swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

     this->type_k = type_k;
@@ -594,8 +596,8 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
     // note: v->nb[1] > v->nb[2]
     return ggml_view_3d(ctx, v,
             n, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_element_size(v)*v->ne[1]*hparams.n_embd_head_v, // v->nb[1]
-            ggml_element_size(v)*v->ne[1],                       // v->nb[2]
+            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
             0);
 }

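Side note on the `get_v` change above: for block-quantized tensor types, `ggml_element_size(v)` reports the byte size of a whole quantization block, so multiplying it by an element count overcounts the row size, while `ggml_row_size(type, ne)` divides by the block size and stays correct for both quantized and non-quantized V caches. A minimal standalone sketch of the difference (not part of the patch; it only assumes `ggml.h` is on the include path):

```cpp
// Illustration only: ggml_row_size() vs. element_size * count for a quantized type.
#include <cstdio>
#include "ggml.h"

int main() {
    const int64_t ne = 4096; // elements per row, e.g. v->ne[1]*n_embd_head_v

    // F16: 2 bytes per element, 1 element per block -> both formulas agree (8192)
    std::printf("f16 : size*ne = %zu, row_size = %zu\n",
            ggml_type_size(GGML_TYPE_F16) * (size_t) ne,
            ggml_row_size (GGML_TYPE_F16, ne));

    // Q8_0: 34-byte blocks of 32 elements -> the naive product is 32x too large
    // (139264 vs. 4352 bytes)
    std::printf("q8_0: size*ne = %zu, row_size = %zu\n",
            ggml_type_size(GGML_TYPE_Q8_0) * (size_t) ne,
            ggml_row_size (GGML_TYPE_Q8_0, ne));

    return 0;
}
```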
@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }

-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -667,41 +669,28 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
             const llama_seq_id seq_id = ubatch->seq_id[s][0];

             for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];

                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id) // not the correct sequence
-                        || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
+                    const llama_pos p0 = cells[i].pos;
+
+                    bool masked = false;
+
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells[i].has_seq_id(seq_id));
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);

-                    if (swa) {
-                        // may need to cut off old tokens for sliding window
-                        switch (hparams.swa_type) {
-                            case LLAMA_SWA_TYPE_STANDARD:
-                                {
-                                    if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                            case LLAMA_SWA_TYPE_CHUNKED:
-                                {
-                                    const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;
-
-                                    if (cells[i].pos < pos_chunk_start) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                        }
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    float f = 0.0f;
+
+                    if (masked) {
+                        f = -INFINITY;
+                    } else if (hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
                     }

                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
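Taken together, the refactored loop computes each mask entry from three independent conditions (wrong sequence, future position under causal attention, outside the sliding window) and only then picks the value: `-INFINITY` when masked, the ALiBi bias `-|p0 - p1|` when ALiBi is in use, and `0.0f` otherwise. A self-contained restatement of that per-entry decision, for illustration only (the `swa_masked` flag stands in for the new `is_masked_swa(p0, p1)` member):

```cpp
// Illustration only: the per-entry KQ mask value after the refactor.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>

static float kq_mask_value(bool same_seq, bool causal_attn, bool use_alibi,
                           bool swa_masked, int32_t p0, int32_t p1) {
    bool masked = false;

    masked = masked || !same_seq;                // not the same sequence
    masked = masked || (causal_attn && p0 > p1); // future token under causal attention
    masked = masked || swa_masked;               // cut off by the sliding window

    if (masked) {
        return -INFINITY;
    }

    return use_alibi ? -std::abs(p0 - p1) : 0.0f;
}

int main() {
    assert(kq_mask_value(true,  true, false, false, 3, 7) ==  0.0f);     // visible
    assert(kq_mask_value(true,  true, true,  false, 3, 7) == -4.0f);     // ALiBi bias
    assert(kq_mask_value(true,  true, false, false, 9, 7) == -INFINITY); // future token
    assert(kq_mask_value(false, true, false, false, 3, 7) == -INFINITY); // other sequence
    return 0;
}
```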
@@ -1191,6 +1180,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }

+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
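The new `is_masked_swa()` helper makes the two window types explicit: `LLAMA_SWA_TYPE_STANDARD` masks any key at least `n_swa` positions behind the query (a fixed trailing window), while `LLAMA_SWA_TYPE_CHUNKED` masks everything before the start of the `n_swa`-sized chunk that contains the query position. A standalone restatement of the predicate with a couple of sanity checks (illustration only; the enum and signature here are local to the sketch, not the ones from `llama.h`):

```cpp
// Illustration only: standalone restatement of the is_masked_swa() predicate.
#include <cassert>
#include <cstdint>

enum swa_type_t { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

// true if key position p0 must be masked for query position p1
static bool is_masked_swa(swa_type_t swa_type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (swa_type) {
        case SWA_NONE:     return false;
        case SWA_STANDARD: return p1 - p0 >= n_swa;          // key too far behind the query
        case SWA_CHUNKED:  return p0 < (p1 / n_swa) * n_swa; // key before the query's chunk
    }
    return false;
}

int main() {
    // standard window, n_swa = 4: only keys within the last 4 positions of p1 = 10 survive (p0 >= 7)
    assert(!is_masked_swa(SWA_STANDARD, 4, 7, 10));
    assert( is_masked_swa(SWA_STANDARD, 4, 6, 10));

    // chunked window, n_swa = 4: keys before the start of p1's chunk (pos 8) are masked
    assert(!is_masked_swa(SWA_CHUNKED, 4, 8, 10));
    assert( is_masked_swa(SWA_CHUNKED, 4, 7, 10));
    return 0;
}
```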
@@ -1586,11 +1599,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);

-    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, kv_size_base, padding,
+            0, LLAMA_SWA_TYPE_NONE);

     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);

-    kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, kv_size_swa, padding,
+            hparams.n_swa, hparams.swa_type);
 }

 void llama_kv_cache_unified_iswa::clear() {