
Commit fe19219

cont : apply to all iSWA models
ggml-ci
1 parent 8a338c5 commit fe19219


6 files changed: +275 additions, -258 deletions


src/llama-graph.cpp

Lines changed: 17 additions & 33 deletions
@@ -362,22 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
-
-    // TODO: remove
-    if (self_kq_mask_swa) {
-        kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
     }
 }
 
@@ -427,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -1241,6 +1235,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
+        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
+
         const auto n_kv = kv_self->get_n();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -1250,19 +1247,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    // TODO: remove
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
-
-        const auto n_kv = kv_self->get_n();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 

@@ -1292,9 +1276,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
     }
 
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = kv_self->get_k(ctx0, il);
@@ -1334,8 +1316,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.n_swa_pattern > 1);
-        GGML_ASSERT(hparams.n_swa > 0);
+        GGML_ASSERT(hparams.n_swa_pattern > 1 && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
 
         const auto n_kv = kv_self->get_kv_swa()->get_n();
 
@@ -1367,21 +1349,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
+    const bool is_swa = hparams.is_swa(il);
+
     const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
     }
 
-    const bool is_swa = hparams.is_swa(il);
-
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
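Taken together, the two build_attn hunks above move the base-vs-SWA decision to the call site: each layer picks one concrete cache half via get_kv_base()/get_kv_swa() and then routes every K/V read and write through it. A minimal, self-contained sketch of that dispatch pattern (stub types only — StubHParams, StubCache and the layer-pattern rule are illustrative stand-ins, not llama.cpp API):

#include <cstdio>

// Illustrative stand-in for llama_hparams; the real SWA-layer rule lives in llama_hparams::is_swa().
struct StubHParams {
    int n_swa_pattern = 6; // assumed example: 5 sliding-window layers followed by 1 full-attention layer
    bool is_swa(int il) const {
        return n_swa_pattern > 1 && (il % n_swa_pattern) != (n_swa_pattern - 1);
    }
};

// Illustrative stand-in for one half of the iSWA cache pair.
struct StubCache {
    const char * name;
    void store(int il) const { std::printf("layer %2d -> %s cache\n", il, name); }
};

int main() {
    StubHParams hparams;
    StubCache kv_base{"base (full-context)"};
    StubCache kv_swa {"SWA (sliding-window)"};

    // Mirrors the new build_attn flow: select the cache half once per layer,
    // then do all reads/writes through that one object.
    for (int il = 0; il < 12; ++il) {
        const bool is_swa = hparams.is_swa(il);
        const StubCache * kv = is_swa ? &kv_swa : &kv_base;
        kv->store(il);
    }
    return 0;
}

Doing the selection once at the call site is what makes the per-layer forwarding helpers on llama_kv_cache_unified_iswa (removed further down in src/llama-kv-cache.cpp) redundant.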

src/llama-graph.h

Lines changed: 0 additions & 4 deletions
@@ -257,12 +257,9 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // TODO: remove
 
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch] // TODO: remove
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch] // TODO: remove
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -404,7 +401,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
     const int64_t n_ctx;         // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;

src/llama-kv-cache.cpp

Lines changed: 14 additions & 95 deletions
@@ -630,7 +630,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -674,68 +674,21 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                     }
                 }
 
-                    if (data) {
-                        data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                    }
-                }
-            }
-        }
-
-        // mask padded tokens
-        if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    float * data = (float *) dst->data;
-
-    const int64_t n_kv = n;
-
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-
-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id) // not the correct sequence
-                     || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
+                    if (swa) {
+                        // may need to cut off old tokens for sliding window
+                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
+                        if (hparams.n_attn_chunk) {
+                            llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                            if (cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                f = -INFINITY;
+                            }
+                        } else if (hparams.n_swa) {
+                            if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
+                                f = -INFINITY;
+                            }
                         }
                     }
 
-                    // may need to cut off old tokens for sliding window
-                    // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                    if (hparams.n_attn_chunk) {
-                        llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                        if (cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                            f = -INFINITY;
-                        }
-                    } else {
-                        if (pos - cells[i].pos >= (int32_t)hparams.n_swa) {
-                            f = -INFINITY;
-                        }
-                    }
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                 }
             }
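The new swa flag folds the former set_input_kq_mask_swa into the shared mask routine; the only extra logic it guards is the cut-off rule visible in the added lines. As a standalone illustration of that rule (a hedged sketch — the swa_masked helper and the demo values are hypothetical, only the two conditions are taken from the hunk above):

#include <cstdint>
#include <cstdio>

using llama_pos = int32_t;

// Returns true when the KV cell at kv_pos must be masked for the query at pos,
// per the added swa branch. n_attn_chunk takes precedence over n_swa,
// mirroring the if / else-if order in the diff.
static bool swa_masked(llama_pos pos, llama_pos kv_pos, uint32_t n_attn_chunk, uint32_t n_swa) {
    if (n_attn_chunk) {
        // chunked attention: a token only attends inside its own chunk
        const llama_pos pos_chunk_start = (pos / (llama_pos) n_attn_chunk) * (llama_pos) n_attn_chunk;
        return kv_pos < pos_chunk_start || pos < pos_chunk_start;
    }
    if (n_swa) {
        // sliding window: anything further back than n_swa tokens is cut off
        return pos - kv_pos >= (int32_t) n_swa;
    }
    return false;
}

int main() {
    // sliding window of 4: the query at pos 10 still sees pos 7, but not pos 6
    std::printf("n_swa=4,  pos=10, kv=7 -> %d\n", swa_masked(10, 7, 0, 4)); // 0 (visible)
    std::printf("n_swa=4,  pos=10, kv=6 -> %d\n", swa_masked(10, 6, 0, 4)); // 1 (masked)

    // chunk size 8: the query at pos 10 sits in chunk [8, 16), so pos 7 is masked
    std::printf("chunk=8, pos=10, kv=7 -> %d\n", swa_masked(10, 7, 8, 0)); // 1 (masked)
    std::printf("chunk=8, pos=10, kv=9 -> %d\n", swa_masked(10, 9, 8, 0)); // 0 (visible)
    return 0;
}

As the TODO in the hunk notes, the chunked-attention case reuses the SWA plumbing, which is why both conditions live behind the same flag.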
@@ -891,8 +844,6 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
   //GGML_ASSERT(kv_self->size == n_ctx);
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
@@ -914,7 +865,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
-        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
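The dropped local duplicated the n_ctx_per_seq member that this commit also removes from llm_graph_context; passing cparams to get_rope_factors presumably lets the callee derive the same per-sequence figure itself. The arithmetic in question, with example numbers only (not llama.cpp defaults):

#include <cstdint>
#include <cstdio>

int main() {
    // The removed local mirrored llm_graph_context::n_ctx_per_seq:
    // the user-specified context window split evenly across parallel sequences.
    const uint32_t n_ctx     = 8192;
    const uint32_t n_seq_max = 4;

    const uint32_t n_ctx_per_seq = n_ctx / n_seq_max;
    std::printf("n_ctx_per_seq = %u\n", n_ctx_per_seq); // 2048
    return 0;
}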
@@ -1736,38 +1687,6 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
     kv_swa ->state_read(io, seq_id);
 }
 
-ggml_tensor * llama_kv_cache_unified_iswa::get_k(ggml_context * ctx, int32_t il) const {
-    if (hparams.is_swa(il)) {
-        return kv_swa->get_k(ctx, il);
-    }
-
-    return kv_base->get_k(ctx, il);
-}
-
-ggml_tensor * llama_kv_cache_unified_iswa::get_v(ggml_context * ctx, int32_t il) const {
-    if (hparams.is_swa(il)) {
-        return kv_swa->get_v(ctx, il);
-    }
-
-    return kv_base->get_v(ctx, il);
-}
-
-ggml_tensor * llama_kv_cache_unified_iswa::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    if (hparams.is_swa(il)) {
-        return kv_swa->cpy_k(ctx, k_cur, il);
-    }
-
-    return kv_base->cpy_k(ctx, k_cur, il);
-}
-
-ggml_tensor * llama_kv_cache_unified_iswa::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    if (hparams.is_swa(il)) {
-        return kv_swa->cpy_v(ctx, v_cur, il);
-    }
-
-    return kv_base->cpy_v(ctx, v_cur, il);
-}
-
 llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
     return kv_base.get();
 }

src/llama-kv-cache.h

Lines changed: 1 addition & 9 deletions
@@ -168,9 +168,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
-    void set_input_kq_mask    (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; // TODO: remove
-
+    void set_input_kq_mask    (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const;
     void set_input_k_shift    (ggml_tensor * dst) const;
     void set_input_pos_bucket (ggml_tensor * dst, const llama_ubatch * ubatch) const;
 

@@ -360,12 +358,6 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
     // llama_kv_cache_unified_iswa specific API
     //
 
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
-
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
-
     llama_kv_cache_unified * get_kv_base() const;
     llama_kv_cache_unified * get_kv_swa () const;
 
