 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
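+        // the relative-position buckets are now filled in by the KV cache itself, which owns the cell positions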
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv = kv_self->n;
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-
-        float * data = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t) hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
-
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
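For reference, the per-cell rule that the removed loop applied, and that the cache's set_input_kq_mask helper is assumed to reproduce, fits in a few lines. This is only a sketch; kq_mask_value and its parameters are illustrative stand-ins for kv_self->cells[i], not the actual cache API:

#include <cmath>    // INFINITY
#include <cstdint>
#include <cstdlib>  // std::abs(int)

// mask value for one KV cell against one batch token
static float kq_mask_value(bool cell_in_seq, int32_t cell_pos, int32_t tok_pos,
                           bool causal_attn, bool use_alibi) {
    if (!cell_in_seq || (causal_attn && cell_pos > tok_pos)) {
        return -INFINITY;                                    // wrong sequence, or a future token under causal attention
    }
    return use_alibi ? -std::abs(cell_pos - tok_pos) : 0.0f; // ALiBi distance bias, otherwise no bias
}

The SWA/chunked variant additionally writes -INFINITY for cells that fall outside the sliding window or the current attention chunk, as the removed data_swa branch did.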
@@ -1153,7 +1024,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,16 +1059,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
-        bool v_trans,
         float kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
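+    // infer the old v_trans flag from the tensor itself: the transposed V layout used by the
+    // KV cache when flash attention is off shows up as swapped strides, i.e. nb[1] > nb[2]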
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
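+    // callers now pass unpermuted q/k/v; the [n_embd_head, n_head, n_tokens] -> [n_embd_head, n_tokens, n_head]
+    // permutation that each call site used to do is applied once, here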
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1336,17 +1203,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1369,17 +1230,21 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    {
+        const auto n_kv = kv_self->get_n();
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
 
     if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
+        const auto n_kv = kv_self->get_n();
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
@@ -1409,81 +1274,22 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    const auto n_tokens = q_cur->ne[2];
-
-    const bool v_trans = !cparams.flash_attn;
 
     // store to KV cache
     {
-        const auto kv_head = kv_self->head;
-
-        GGML_ASSERT(kv_self->size == n_ctx);
-
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
-
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
-
-        ggml_tensor * v_cache_view = nullptr;
-
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
-
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
-
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
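+        // as before, the RoPE-ed K is what gets cached; the view/copy details (including the transposed
+        // V layout when flash attention is off) now live inside the cache's cpy_k/cpy_v helpers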
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
     }
 
     const bool is_swa = hparams.is_swa(il);
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
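+    // the cache also builds the K/V views for this layer, replacing the explicit ggml_view_3d calls below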
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1534,17 +1340,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1712,3 +1512,30 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
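A quick hand-computed sanity check of the bucketing (not part of the PR; it assumes 32 buckets, the usual T5 value for n_rel_attn_bkts, and that the now non-static function is declared where the caller can see it): distances below max_exact = 16 keep their own bucket, larger ones are binned logarithmically up to max_distance = 128, and everything farther saturates in the last bucket.

#include <cassert>

int main() {
    // unidirectional case: x is the (earlier) KV cell position, y the current token position
    assert(llama_relative_position_bucket(10,  10, 32, false) ==  0); // distance 0
    assert(llama_relative_position_bucket( 0,  15, 32, false) == 15); // still in the exact range
    assert(llama_relative_position_bucket( 0,  64, 32, false) == 26); // log-spaced bucket
    assert(llama_relative_position_bucket( 0, 500, 32, false) == 31); // clamped to the last bucket
}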