Commit 621986d

kv-cache : prepare for SWA
ggml-ci
1 parent b283804 commit 621986d

File tree

5 files changed: +580 -459 lines changed

src/llama-graph.cpp

Lines changed: 61 additions & 234 deletions
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv = kv_self->n;
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-
-        float * data = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
-
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -1153,7 +1024,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,16 +1059,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
-        bool v_trans,
         float kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    //const int64_t n_head = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1336,17 +1203,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1369,17 +1230,21 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    {
+        const auto n_kv = kv_self->get_n();
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
 
     if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
+        const auto n_kv = kv_self->get_n();
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
14091274
ggml_build_forward_expand(gf, v_cur);
14101275

14111276
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1412-
const auto & n_ctx = cparams.n_ctx;
1413-
1414-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1415-
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1416-
1417-
const auto n_tokens = q_cur->ne[2];
1418-
1419-
const bool v_trans = !cparams.flash_attn;
14201277

14211278
// store to KV cache
14221279
{
1423-
const auto kv_head = kv_self->head;
1424-
1425-
GGML_ASSERT(kv_self->size == n_ctx);
1426-
1427-
ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
1428-
//cb(k_cache_view, "k_cache_view", il);
1429-
1430-
// note: storing RoPE-ed version of K in the KV cache
1431-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
1432-
1433-
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
1434-
1435-
ggml_tensor * v_cache_view = nullptr;
1436-
1437-
if (!v_trans) {
1438-
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
1439-
} else {
1440-
// note: the V cache is transposed when not using flash attention
1441-
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
1442-
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
1443-
(kv_head)*ggml_element_size(kv_self->v_l[il]));
1444-
1445-
v_cur = ggml_transpose(ctx0, v_cur);
1446-
}
1447-
//cb(v_cache_view, "v_cache_view", il);
1448-
1449-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
1280+
ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1281+
ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
14501282
}
14511283

14521284
const bool is_swa = hparams.is_swa(il);
14531285

14541286
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
14551287

1456-
const auto n_kv = kv_self->n;
1288+
ggml_tensor * q = q_cur;
1289+
ggml_tensor * k = kv_self->get_k(ctx0, il);
1290+
ggml_tensor * v = kv_self->get_v(ctx0, il);
14571291

1458-
const int64_t n_head_kv = hparams.n_head_kv(il);
1459-
1460-
const auto & n_embd_head_k = hparams.n_embd_head_k;
1461-
const auto & n_embd_head_v = hparams.n_embd_head_v;
1462-
1463-
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1464-
//cb(q, "q", il);
1465-
1466-
ggml_tensor * k =
1467-
ggml_view_3d(ctx0, kv_self->k_l[il],
1468-
n_embd_head_k, n_kv, n_head_kv,
1469-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
1470-
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
1471-
0);
1472-
//cb(k, "k", il);
1473-
1474-
ggml_tensor * v = !v_trans ?
1475-
ggml_view_3d(ctx0, kv_self->v_l[il],
1476-
n_embd_head_v, n_kv, n_head_kv,
1477-
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
1478-
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
1479-
0) :
1480-
ggml_view_3d(ctx0, kv_self->v_l[il],
1481-
n_kv, n_embd_head_v, n_head_kv,
1482-
ggml_element_size(kv_self->v_l[il])*n_ctx,
1483-
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
1484-
0);
1485-
1486-
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
1292+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
14871293
cb(cur, "kqv_out", il);
14881294

14891295
if (wo) {
@@ -1534,17 +1340,11 @@ ggml_tensor * llm_graph_context::build_attn(
15341340

15351341
const auto & kq_mask = inp->get_kq_mask_cross();
15361342

1537-
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
1538-
//cb(q, "q", il);
1539-
1540-
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
1541-
//cb(k, "k", il);
1542-
1543-
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
1544-
//cb(k, "v", il);
1545-
1546-
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
1343+
ggml_tensor * q = q_cur;
1344+
ggml_tensor * k = k_cur;
1345+
ggml_tensor * v = v_cur;
15471346

1347+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
15481348
cb(cur, "kqv_out", il);
15491349

15501350
if (wo) {
@@ -1712,3 +1512,30 @@ void llm_graph_context::build_pooling(
17121512

17131513
ggml_build_forward_expand(gf, cur);
17141514
}
1515+
1516+
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1517+
// TODO move to hparams if a T5 variant appears that uses a different value
1518+
const int64_t max_distance = 128;
1519+
1520+
if (bidirectional) {
1521+
n_buckets >>= 1;
1522+
}
1523+
1524+
const int64_t max_exact = n_buckets >> 1;
1525+
1526+
int32_t relative_position = x - y;
1527+
int32_t relative_bucket = 0;
1528+
1529+
if (bidirectional) {
1530+
relative_bucket += (relative_position > 0) * n_buckets;
1531+
relative_position = abs(relative_position);
1532+
} else {
1533+
relative_position = -std::min<int32_t>(relative_position, 0);
1534+
}
1535+
1536+
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
1537+
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
1538+
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
1539+
1540+
return relative_bucket;
1541+
}
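
Taken together, the hunks above replace direct access to kv-cache internals (kv_self->n, kv_self->head, kv_self->cells, kv_self->k_l/v_l) with calls into llama_kv_cache_unified, and build_attn_mha() now derives whether the V view is transposed from its strides (v->nb[1] > v->nb[2]) instead of taking an explicit v_trans flag. Below is a minimal sketch of the cache interface these new call sites assume; the signatures are inferred only from the usage in this file, not copied from the commit's kv-cache header, so the real declarations may differ.

// Hypothetical sketch, inferred from the call sites in src/llama-graph.cpp above.
// Not the actual declaration shipped in this commit's kv-cache header.
#include <cstdint>

struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;

class llama_kv_cache_unified {
public:
    // number of KV cells the current graph should attend over (return type assumed)
    uint32_t get_n() const;

    // write the F32 KQ mask (and its SWA/chunked variant) for the given ubatch
    void set_input_kq_mask    (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;

    // write the T5-style relative position buckets for the given ubatch
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    // graph ops that copy the current K/V of layer il into the cache at the current head
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

    // views over the cached K/V of layer il, sized to get_n() cells
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
};

Keeping the mask construction and the K/V views behind this interface is presumably what "prepare for SWA" refers to: a sliding-window-aware cache can later change how many cells it exposes and how it fills the masks without the graph-building code changing again.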
