Commit d691ff8

kv-cache : initial iSWA implementation
ggml-ci
1 parent 621986d commit d691ff8

File tree

7 files changed: +679 additions, -275 deletions

src/llama-graph.cpp

Lines changed: 94 additions & 15 deletions
@@ -362,11 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+    }
+}
+
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
     }
 }
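The new iSWA input routes the regular mask through get_kv_base() and the sliding-window mask through get_kv_swa(). The llama_kv_cache_unified_iswa type itself lives in the KV-cache sources changed by this commit but not shown in this excerpt; as a minimal sketch of the wrapper interface the code above relies on (member names and ownership are assumptions, not the actual definition):

#include <memory>

class llama_kv_cache_unified; // existing cache class from llama-kv-cache.h

// sketch only: a dual-cache wrapper exposing the two getters used by set_input()
class llama_kv_cache_unified_iswa {
public:
    // full-context cache used by the non-SWA layers
    llama_kv_cache_unified * get_kv_base() const { return kv_base.get(); }
    // small rolling cache used by the sliding-window layers
    llama_kv_cache_unified * get_kv_swa()  const { return kv_swa.get(); }

private:
    std::unique_ptr<llama_kv_cache_unified> kv_base; // illustrative member names
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};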

@@ -416,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer       (hparams.n_layer),
     n_rot         (hparams.n_rot),
     n_ctx         (cparams.n_ctx),
-    n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
     n_head        (hparams.n_head()),
     n_head_kv     (hparams.n_head_kv()),
     n_embd_head_k (hparams.n_embd_head_k),
@@ -1231,6 +1236,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
+        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
+
         const auto n_kv = kv_self->get_n();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -1240,10 +1248,79 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+}
 
-        const auto n_kv = kv_self->get_n();
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.n_swa_pattern > 1 && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1252,11 +1329,11 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1273,21 +1350,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const bool is_swa = hparams.is_swa(il);
+
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
     }
 
-    const bool is_swa = hparams.is_swa(il);
-
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

src/llama-graph.h

Lines changed: 42 additions & 2 deletions
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +292,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -378,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
    const int64_t n_rot;
     const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -545,6 +570,21 @@ struct llm_graph_context {
            float kq_scale,
            int il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
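The asserts added in llama-graph.cpp imply that the cache implementation is chosen from the SWA hyperparameters when the memory is constructed; that construction code is in the parts of the commit not shown in this excerpt. A minimal sketch of the dispatch the asserts suggest (the enum and helper below are illustrative, not llama.cpp API, and the real constructors are deliberately not guessed at):

#include <cstdint>

// illustrative helper: which KV-cache implementation a model's hparams call for
enum class kv_cache_kind { unified, unified_iswa };

static kv_cache_kind pick_kv_cache_kind(uint32_t n_swa, uint32_t n_swa_pattern) {
    // build_attn_inp_kv_unified()      asserts n_swa == 0 && n_swa_pattern == 1
    // build_attn_inp_kv_unified_iswa() asserts n_swa  > 0 && n_swa_pattern  > 1
    return (n_swa > 0 && n_swa_pattern > 1) ? kv_cache_kind::unified_iswa
                                            : kv_cache_kind::unified;
}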

src/llama-hparams.h

Lines changed: 13 additions & 5 deletions
@@ -14,6 +14,11 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_STANDARD = 0,
+    LLAMA_SWA_TYPE_CHUNKED  = 1,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,8 +40,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0;         // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +99,12 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_STANDARD;
+
+    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
+
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
@@ -116,11 +125,10 @@ struct llama_hparams {
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
 
+    // llama4
     uint32_t n_moe_layer_step = 0;
-    bool     use_kq_norm = true;
-    uint32_t n_attn_chunk = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale = 0.1;
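The new build_attn() overload keys off hparams.is_swa(il), whose definition is not part of this excerpt; n_swa_pattern encodes how SWA and full-attention layers interleave (e.g. with a pattern of 6, one layer in every group of six keeps full attention). The new LLAMA_SWA_TYPE_CHUNKED value, together with the removal of n_attn_chunk, suggests llama4-style chunked attention is now expressed through the same SWA fields. As a self-contained illustration of one plausible per-layer selection rule, treating the exact indexing convention as an assumption:

#include <cstdint>
#include <cstdio>

// illustration only: one way to derive the per-layer SWA flag from n_swa/n_swa_pattern;
// the exact convention used by llama_hparams::is_swa() may differ
static bool layer_is_swa(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
    if (n_swa == 0 || n_swa_pattern <= 1) {
        return false; // no sliding window configured: every layer sees the full context
    }
    // within each group of n_swa_pattern layers, the last one keeps full attention
    return il % n_swa_pattern < n_swa_pattern - 1;
}

int main() {
    const uint32_t n_swa = 1024, n_swa_pattern = 6, n_layer = 12;
    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %2u: %s\n", il, layer_is_swa(il, n_swa, n_swa_pattern) ? "SWA" : "full");
    }
    return 0;
}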
