
Commit cefd037

kv-cache : simplify SWA logic
ggml-ci
1 parent 65eee87 commit cefd037

5 files changed: +73 -45 lines changed


src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
@@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
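
The call sites above no longer tell the cache whether to apply SWA; with this commit that decision lives in the cache object itself. Below is a minimal C++ sketch of the shape of that change, using hypothetical stand-in structs rather than the real llama.cpp classes (only the enum mirrors the diff):

// Hypothetical stand-in types; only the shape of the change is meant to mirror the diff.
#include <cstdint>

enum llama_swa_type { // as introduced in src/llama-hparams.h below
    LLAMA_SWA_TYPE_NONE     = 0,
    LLAMA_SWA_TYPE_STANDARD = 1,
    LLAMA_SWA_TYPE_CHUNKED  = 2,
};

struct kv_cache_before {
    // old shape: SWA on/off chosen by the caller on every call
    void set_input_kq_mask(bool causal_attn, bool swa) const { /* build mask */ }
};

struct kv_cache_after {
    uint32_t       n_swa    = 0;                   // fixed at construction
    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // fixed at construction

    // new shape: SWA behavior derived from the instance's own configuration
    void set_input_kq_mask(bool causal_attn) const { /* build mask */ }
};

This is why the iSWA wrapper in llama-kv-cache.cpp below can construct its base cache with LLAMA_SWA_TYPE_NONE and its SWA cache with the model's hyperparameters.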

src/llama-hparams.h

Lines changed: 4 additions & 3 deletions
@@ -15,8 +15,9 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
-    LLAMA_SWA_TYPE_STANDARD = 0,
-    LLAMA_SWA_TYPE_CHUNKED  = 1,
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
 };
 
 struct llama_hparams_posnet {
@@ -100,7 +101,7 @@ struct llama_hparams {
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
-    llama_swa_type swa_type = LLAMA_SWA_TYPE_STANDARD;
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
     uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention

src/llama-kv-cache.cpp

Lines changed: 52 additions & 36 deletions
@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
-        uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+        uint32_t padding,
+        uint32_t n_swa,
+        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
     this->type_k = type_k;
@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -670,38 +672,23 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
 
                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id) // not the correct sequence
-                        || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
+                    bool masked = false;
+
+                    // mask the token if not the same sequence
+                    masked |= !cells[i].has_seq_id(seq_id);
+
+                    // mask future tokens
+                    masked |= causal_attn && cells[i].pos > pos;
+
+                    // apply SWA if any
+                    masked |= is_masked_swa(cells[i].pos, pos);
 
-                    if (swa) {
-                        // may need to cut off old tokens for sliding window
-                        switch (hparams.swa_type) {
-                            case LLAMA_SWA_TYPE_STANDARD:
-                                {
-                                    if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                            case LLAMA_SWA_TYPE_CHUNKED:
-                                {
-                                    const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;
-
-                                    if (cells[i].pos < pos_chunk_start) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                        }
+                    float f = 0.0f;
+
+                    if (masked) {
+                        f = -INFINITY;
+                    } else if (hparams.use_alibi) {
+                        f = -std::abs(cells[i].pos - pos);
                     }
 
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
@@ -1191,6 +1178,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }
 
+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
@@ -1586,11 +1597,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, kv_size_base, padding,
+            0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, kv_size_swa, padding,
+            hparams.n_swa, hparams.swa_type);
 }
 
 void llama_kv_cache_unified_iswa::clear() {
@@ -2801,5 +2818,4 @@ void llama_kv_cache_view_free(llama_kv_cache_view * view) {
 void llama_kv_cache_view_update(llama_kv_cache_view * , const llama_kv_cache * ) {
     // TODO: will be removed soon, keep this for now to avoid too many changes in
     // https://github.com/ggml-org/llama.cpp/pull/13194
-    GGML_ABORT("not implemented");
 }
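
The new is_masked_swa() centralizes the two sliding-window rules that used to live inline in set_input_kq_mask(). The following self-contained sketch (not part of the commit) mirrors that logic so the two rules can be checked in isolation; the enum names, window size, and main() driver here are illustrative assumptions:

#include <cstdint>
#include <cstdio>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

// p0 = position of the cached token, p1 = position of the current (query) token
static bool is_masked_swa(swa_type type, uint32_t n_swa, int64_t p0, int64_t p1) {
    switch (type) {
        case SWA_NONE:
            return false; // full attention: SWA never masks anything
        case SWA_STANDARD:
            // standard sliding window: mask tokens at least n_swa positions behind the query
            return p1 - p0 >= (int64_t) n_swa;
        case SWA_CHUNKED:
            // chunked SWA: mask tokens before the start of the chunk containing the query
            return p0 < (p1 / (int64_t) n_swa) * (int64_t) n_swa;
    }
    return false;
}

int main() {
    const uint32_t n_swa = 8;  // illustrative window size
    const int64_t  p1    = 20; // query position

    for (int64_t p0 = 10; p0 <= 20; ++p0) {
        printf("p0=%2lld  standard_masked=%d  chunked_masked=%d\n", (long long) p0,
               (int) is_masked_swa(SWA_STANDARD, n_swa, p0, p1),
               (int) is_masked_swa(SWA_CHUNKED,  n_swa, p0, p1));
    }
    return 0;
}

With these rules, a query at position 20 and n_swa = 8 attends to cache positions 13..20 under the standard rule, while the chunked rule keeps only positions from the start of the query's chunk, 16..20 in this example.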

src/llama-kv-cache.h

Lines changed: 11 additions & 2 deletions
@@ -102,7 +102,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
             bool v_trans,
             bool offload,
             uint32_t kv_size,
-            uint32_t padding);
+            uint32_t padding,
+            uint32_t n_swa,
+            llama_swa_type swa_type);
 
     ~llama_kv_cache_unified() = default;
 
@@ -169,7 +171,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
-    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const;
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -223,6 +225,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    // SWA
+    uint32_t n_swa = 0;
+
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
@@ -264,6 +271,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
             ggml_context * ctx,

src/llama-model.cpp

Lines changed: 3 additions & 1 deletion
@@ -13228,7 +13228,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
                     cparams.n_ctx,
-                    padding);
+                    padding,
+                    hparams.n_swa,
+                    hparams.swa_type);
             }
         }
     }
