
Commit 8a338c5

kv-cache : initial iSWA implementation
ggml-ci
1 parent 016774b commit 8a338c5

File tree

5 files changed: +415 −5 lines changed


src/llama-graph.cpp

Lines changed: 95 additions & 0 deletions
@@ -365,11 +365,22 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
         kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
+    // TODO: remove
     if (self_kq_mask_swa) {
         kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
+    }
+}
+
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     if (cross_kq_mask) {
         const int64_t n_enc = cross_kq_mask->ne[0];
@@ -1239,6 +1250,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    // TODO: remove
     if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
@@ -1306,6 +1318,89 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.n_swa_pattern > 1);
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
     auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
 
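Note how the new build_attn() overload routes each layer to one of the two masks via hparams.is_swa(il). The following standalone sketch mirrors that routing with toy types instead of the real llama.cpp classes; the is_swa() rule shown here (every n_swa_pattern-th layer uses full attention) and the pattern value are illustrative assumptions, not values taken from this commit.

```cpp
// Standalone sketch (hypothetical types, not the real llama.cpp API) of the
// per-layer mask selection performed by the new build_attn() overload.
#include <cstdint>
#include <cstdio>

struct toy_hparams {
    uint32_t n_swa_pattern = 6; // assumed: 5 SWA layers followed by 1 full-attention layer

    bool is_swa(uint32_t il) const {
        // assumed convention: the last layer of each pattern block uses full attention
        return il % n_swa_pattern < (n_swa_pattern - 1);
    }
};

struct toy_attn_input {
    const char * kq_mask     = "KQ_mask";     // stands in for self_kq_mask_cnv
    const char * kq_mask_swa = "KQ_mask_swa"; // stands in for self_kq_mask_swa_cnv
};

int main() {
    toy_hparams   hparams;
    toy_attn_input inp;

    for (uint32_t il = 0; il < 12; ++il) {
        // mirrors build_attn(): each layer reads either the SWA mask or the base mask
        const char * mask = hparams.is_swa(il) ? inp.kq_mask_swa : inp.kq_mask;
        std::printf("layer %2u -> %s\n", il, mask);
    }
    return 0;
}
```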
src/llama-graph.h

Lines changed: 45 additions & 1 deletion
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,34 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // TODO: remove
+
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch] // TODO: remove
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch] // TODO: remove
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +295,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -542,6 +571,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
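The mask tensors declared above are F32 with shape [n_kv, n_batch], where the batch dimension is padded before the tensor is created. Below is a minimal sketch of that shape arithmetic; the pad constant, the cell counts, and the token count are illustrative assumptions (GGML_KQ_MASK_PAD_ASSUMED is a hypothetical stand-in, not the value defined in llama.cpp).

```cpp
// Sketch of the KQ-mask shape arithmetic implied by the declarations above.
#include <cstdint>
#include <cstdio>

// same rounding that the GGML_PAD macro performs
constexpr int64_t pad_up(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t GGML_KQ_MASK_PAD_ASSUMED = 32; // assumed value, for illustration only

    const int64_t n_kv_base = 4096; // cells visible to the non-SWA cache (assumed)
    const int64_t n_kv_swa  = 1024; // cells visible to the SWA cache (assumed)
    const int64_t n_tokens  = 7;    // tokens in the current ubatch (assumed)

    const int64_t n_batch = pad_up(n_tokens, GGML_KQ_MASK_PAD_ASSUMED);

    std::printf("self_kq_mask     : F32 [%lld, %lld]\n", (long long) n_kv_base, (long long) n_batch);
    std::printf("self_kq_mask_swa : F32 [%lld, %lld]\n", (long long) n_kv_swa,  (long long) n_batch);
    return 0;
}
```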

src/llama-kv-cache.cpp

Lines changed: 176 additions & 0 deletions
@@ -906,6 +906,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        // TODO: move to model.get_freq_factors()
         const bool is_swa = hparams.is_swa(il);
 
         // note: the swa rope params could become part of the cparams in the future
@@ -1600,6 +1601,181 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
+//
+// llama_kv_cache_unified_iswa
+//
+
+llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        uint32_t kv_size,
+        uint32_t padding) : hparams(model.hparams) {
+    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+
+    // TODO: provide from the llama_context
+    const uint32_t n_seq_max = 1;
+
+    const uint32_t kv_size_base = kv_size;
+    const uint32_t kv_size_swa  = hparams.n_swa*n_seq_max;
+
+    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_swa  = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa),  type_k, type_v, v_trans, offload, kv_size_swa,  padding);
+}
+
+void llama_kv_cache_unified_iswa::clear() {
+    kv_base->clear();
+    kv_swa ->clear();
+}
+
+bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    bool res = true;
+
+    res = res & kv_base->seq_rm(seq_id, p0, p1);
+    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+    kv_base->seq_keep(seq_id);
+    kv_swa ->seq_keep(seq_id);
+}
+
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    kv_base->seq_add(seq_id, p0, p1, delta);
+    kv_swa ->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_base->seq_div(seq_id, p0, p1, d);
+    kv_swa ->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    return kv_base->seq_pos_max(seq_id);
+}
+
+void llama_kv_cache_unified_iswa::restore() {
+    kv_base->restore();
+    kv_swa ->restore();
+}
+
+void llama_kv_cache_unified_iswa::commit() {
+    kv_base->commit();
+    kv_swa ->commit();
+}
+
+bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
+    bool res = true;
+
+    res = res & kv_base->update(lctx);
+    res = res & kv_swa ->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
+    kv_base->defrag_sched(thold);
+    kv_swa ->defrag_sched(thold);
+}
+
+void llama_kv_cache_unified_iswa::set_full() {
+    kv_base->set_full();
+    kv_swa ->set_full();
+}
+
+llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    return kv_base->sbatch_init(batch, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+}
+
+bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
+    bool res = true;
+
+    res = res & kv_base->find_slot(batch);
+    res = res & kv_swa ->find_slot(batch);
+
+    return res;
+}
+
+int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
+    return kv_base->get_n_tokens();
+}
+
+int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
+    return kv_base->get_used_cells();
+}
+
+llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
+    return kv_base->get_pos_max();
+}
+
+bool llama_kv_cache_unified_iswa::get_can_shift() const {
+    return false;
+}
+
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_base->state_write(io, seq_id);
+    kv_swa ->state_write(io, seq_id);
+}
+
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_base->state_read(io, seq_id);
+    kv_swa ->state_read(io, seq_id);
+}
+
+ggml_tensor * llama_kv_cache_unified_iswa::get_k(ggml_context * ctx, int32_t il) const {
+    if (hparams.is_swa(il)) {
+        return kv_swa->get_k(ctx, il);
+    }
+
+    return kv_base->get_k(ctx, il);
+}
+
+ggml_tensor * llama_kv_cache_unified_iswa::get_v(ggml_context * ctx, int32_t il) const {
+    if (hparams.is_swa(il)) {
+        return kv_swa->get_v(ctx, il);
+    }
+
+    return kv_base->get_v(ctx, il);
+}
+
+ggml_tensor * llama_kv_cache_unified_iswa::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    if (hparams.is_swa(il)) {
+        return kv_swa->cpy_k(ctx, k_cur, il);
+    }
+
+    return kv_base->cpy_k(ctx, k_cur, il);
+}
+
+ggml_tensor * llama_kv_cache_unified_iswa::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    if (hparams.is_swa(il)) {
+        return kv_swa->cpy_v(ctx, v_cur, il);
+    }
+
+    return kv_base->cpy_v(ctx, v_cur, il);
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const {
+    return kv_base.get();
+}
+
+llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
+    return kv_swa.get();
+}
+
 //
 // llama_kv_cache_recurrent
 //