@@ -15,27 +15,22 @@
 // llama_kv_cache_unified
 //
 
-llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
-}
+llama_kv_cache_unified::llama_kv_cache_unified(
+        const llama_hparams & hparams,
+        callbacks cbs,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) {
 
-bool llama_kv_cache_unified::init(
-        const llama_model & model,
-        const llama_cparams & cparams,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size,
-        bool offload) {
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
 
-    GGML_ASSERT(!llama_model_is_recurrent(&model));
-
-    v_trans = !cparams.flash_attn;
     can_shift = true;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
 
     head = 0;
     size = kv_size;
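
Usage note: with init() folded into the constructor, a construction failure now surfaces as an exception instead of a bool. A minimal sketch of the new call site, assuming the callbacks struct declared in the header exposes a std::function-style get_buft member and that model/cparams are in scope (the cell types and kv_size below are illustrative):

    llama_kv_cache_unified::callbacks cbs;
    cbs.get_buft = [](int32_t il) {
        (void) il;                              // same buffer type for every layer
        return ggml_backend_cpu_buffer_type();  // no offload in this sketch
    };

    try {
        llama_kv_cache_unified kv(
                model.hparams, std::move(cbs),
                GGML_TYPE_F16, GGML_TYPE_F16,
                /*v_trans =*/ !cparams.flash_attn, // moved out of init(), see the deleted line above
                /*kv_size =*/ 4096);
    } catch (const std::runtime_error & err) {
        LLAMA_LOG_ERROR("%s: failed to create KV cache: %s\n", __func__, err.what());
    }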
@@ -79,25 +74,11 @@ bool llama_kv_cache_unified::init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer(i);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        ggml_backend_buffer_type_t buft = cbs.get_buft(i);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -115,15 +96,12 @@ bool llama_kv_cache_unified::init(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to allocate buffer for kv cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
-
-    return true;
 }
 
 int32_t llama_kv_cache_unified::get_n_tokens() const {
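
The per-layer device selection deleted from init() above is not lost; it moves behind cbs.get_buft at the call site. A sketch of a callback reproducing the old offload behavior (capturing model and offload from the surrounding scope is an assumption, not shown in this diff):

    cbs.get_buft = [&model, offload](int32_t il) {
        if (offload) {
            // the same per-layer device lookup the old init() performed
            return ggml_backend_dev_buffer_type(model.dev_layer(il));
        }
        return ggml_backend_cpu_buffer_type();
    };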
@@ -480,7 +458,7 @@ bool llama_kv_cache_unified::find_slot(
     return true;
 }
 
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
 }
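
With the const qualifier dropped, get_padding no longer depends on instance state, so the header can declare it static and callers can query the padding before any cache exists, e.g. when rounding up the context size (a sketch; the static declaration is assumed from the header, not shown here):

    // round the requested context up to the KV padding granularity
    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));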
@@ -1021,24 +999,16 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // llama_kv_cache_recurrent
 //
 
-llama_kv_cache_recurrent::llama_kv_cache_recurrent(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
-}
-
-bool llama_kv_cache_recurrent::init(
-        const llama_model & model,
-        const llama_cparams & cparams,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size,
-        bool offload) {
-    GGML_UNUSED(cparams);
-
+llama_kv_cache_recurrent::llama_kv_cache_recurrent(
+        const llama_hparams & hparams,
+        callbacks cbs,
+        ggml_type type_k,
+        ggml_type type_v,
+        uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)) {
     const int32_t n_layer = hparams.n_layer;
 
-    GGML_ASSERT(llama_model_is_recurrent(&model));
-
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     head = 0;
     size = kv_size;
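
The recurrent cache follows the same constructor pattern, minus the v_trans flag (flash-attention V transposition does not apply to recurrent state). A sketch with the same assumed callbacks wiring as above; the F32 cell types and n_seq_max size are illustrative:

    llama_kv_cache_recurrent kv_r(
            model.hparams, std::move(cbs),
            GGML_TYPE_F32, GGML_TYPE_F32,
            /*kv_size =*/ n_seq_max);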
@@ -1082,25 +1052,11 @@ bool llama_kv_cache_recurrent::init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer(i);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        ggml_backend_buffer_type_t buft = cbs.get_buft(i);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -1118,15 +1074,12 @@ bool llama_kv_cache_recurrent::init(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to allocate buffer for kv cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
-
-    return true;
 }
 
 int32_t llama_kv_cache_recurrent::get_n_tokens() const {
@@ -1558,11 +1511,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-uint32_t llama_kv_cache_recurrent::get_padding(const llama_cparams & cparams) const {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
 uint32_t llama_kv_cache_recurrent::cell_max() const {
     for (uint32_t i = size; i > 0; --i) {
         const llama_kv_cell & cell = cells[i - 1];