@@ -20,6 +20,8 @@ llama_context::llama_context(
     model     (model),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
+    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
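
A note on the constructor log lines this commit adds (here and, below, in `llama_context_kv_self` and `llama_context_recurrent`): C++ runs base-class constructors before derived ones, so constructing the most derived context logs base-to-derived. A minimal self-contained illustration with stand-in class names, not the real types:

```cpp
#include <cstdio>

struct ctx                         { ctx()           { puts("constructing ctx"); } };
struct ctx_kv_self   : ctx         { ctx_kv_self()   { puts("constructing ctx_kv_self"); } };
struct ctx_recurrent : ctx_kv_self { ctx_recurrent() { puts("constructing ctx_recurrent"); } };

int main() {
    ctx_recurrent c; // prints ctx, then ctx_kv_self, then ctx_recurrent
}
```
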
@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
         const llama_context_params & params) :
     llama_context(model, params),
     kv_self(model.hparams) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
+
     const auto & hparams = model.hparams;
 
     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_KQ_mask_swa_cnv = nullptr;
     inp_KQ_mask_cross   = nullptr;
     inp_k_shift         = nullptr;
-    inp_s_copy          = nullptr;
-    inp_s_mask          = nullptr;
     inp_embd_enc        = nullptr;
     inp_pos_bucket      = nullptr;
 
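
The two pointers dropped here reappear in `llama_context_recurrent::graph_init()` further down, which clears the inputs it owns and then defers to this base implementation. A minimal self-contained sketch of that override pattern, with invented names:

```cpp
#include <cstdio>

struct base_ctx {
    int * inp_k_shift = nullptr;
    virtual ~base_ctx() = default;
    virtual void graph_init() {
        inp_k_shift = nullptr;  // base clears the inputs it owns
    }
};

struct recurrent_ctx : base_ctx {
    int * inp_s_copy = nullptr;
    int * inp_s_mask = nullptr;
    void graph_init() override {
        inp_s_copy = nullptr;   // subclass clears its own inputs first...
        inp_s_mask = nullptr;
        base_ctx::graph_init(); // ...then delegates to the base
    }
};

int main() {
    recurrent_ctx c;
    c.graph_init();
    puts("ok");
}
```
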
@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }
 
-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-            float * data = (float *) inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // TODO: do not mutate the KV cache
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-            int32_t * data = (int32_t *) inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
     if (inp_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;
 
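
The block removed above moves essentially verbatim into `llama_context_recurrent::input_set()` later in this commit. For what the two buffers encode, here is a self-contained rerun of the fill logic, condensed into one pass: `src` mirrors `llama_kv_cell::src` with `-1` meaning "no state yet", and the cell contents and window `[head, head + n_kv)` are invented for the example.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int32_t> src = { -1, 0, 1, -1 }; // per-cell source state (-1 = fresh)
    const uint32_t head = 0, n_kv = 4;

    std::vector<float>   s_mask(n_kv);
    std::vector<int32_t> s_copy(n_kv);

    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t cell_id = i + head;

        s_mask[i] = (float) (src[cell_id] >= 0);      // 0.0 clears a fresh state
        if (src[cell_id] < 0) {
            src[cell_id] = cell_id;                   // only clear once
        }

        s_copy[i] = src[cell_id];                     // row to gather the state from
        if (src[cell_id] != (int32_t) cell_id) {
            src[cell_id] = cell_id;                   // ensure copy happens only once
        }
    }

    for (uint32_t i = 0; i < n_kv; ++i) {
        printf("cell %u: mask=%.0f copy_from=%d\n", i, s_mask[i], s_copy[i]);
    }
}
```
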
@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(
 
 void llama_context_kv_self::build_attn_kv_store(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         int32_t n_tokens,
@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
     // cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store(
     }
     // cb(v_cache_view, "v_cache_view", il);
 
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 }
 
 ggml_tensor * llama_context_kv_self::build_attn_qkv(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         }
     }
 
-    ggml_build_forward_expand(graph, cur);
+    ggml_build_forward_expand(gf, cur);
 
     if (wo) {
         cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
     return inp_KQ_mask_cross;
 }
 
-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_recurrent
+//
+
+llama_context_recurrent::llama_context_recurrent(
+        const llama_model & model,
+        const llama_context_params & params) :
+    llama_context_kv_self(model, params) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+}
+
+llama_context_recurrent::~llama_context_recurrent() = default;
+
+ggml_cgraph * llama_context_recurrent::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context_kv_self::graph_init();
+}
+
+void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
+    // call base functionality
+    llama_context_kv_self::input_set(ubatch);
+
+    GGML_ASSERT(kv_self.recurrent);
+
+    const int64_t n_kv = kv_self.n;
+
+    if (inp_s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+        float * data = (float *) inp_s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // TODO: do not mutate the KV cache
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+
+    if (inp_s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+        int32_t * data = (int32_t *) inp_s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+ggml_tensor * llama_context_recurrent::build_inp_s_copy(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
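
With the subclass in place, a natural call site would pick the context type from the model architecture. The sketch below is hypothetical and not part of this commit: `make_context()` is invented for illustration, while `llama_model_is_recurrent()` is the existing public predicate from `llama.h`.

```cpp
// Hypothetical factory, not in this commit.
static llama_context * make_context(
        const llama_model          & model,
        const llama_context_params & params) {
    if (llama_model_is_recurrent(&model)) {
        // Mamba, RWKV, ...: needs the s_copy/s_mask state inputs
        return new llama_context_recurrent(model, params);
    }
    return new llama_context_kv_self(model, params);
}
```
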
@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
     return inp_s_copy;
 }
 
-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_recurrent::build_inp_s_mask(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
     return inp_s_mask;
 }
 
-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_recurrent::build_copy_mask_state(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * s,
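
For context on what `build_copy_mask_state` assembles: llama.cpp has long built this as gather-then-mask over the per-cell states. A condensed sketch under that assumption, with tensor shapes as my reading of the inputs above; the exact views and offsets in this commit's implementation may differ:

```cpp
#include "ggml.h"

// 1) view the flat state buffer as one row per cache cell,
// 2) gather each cell's row from its source cell (inp_s_copy),
// 3) multiply by the 0/1 mask so fresh sequences start from zeros (inp_s_mask).
static ggml_tensor * copy_mask_states_sketch(
        ggml_context * ctx0,
        ggml_tensor  * s,          // flat buffer of n_state * kv_size values
        ggml_tensor  * state_copy, // I32 [n_kv], source row per cell
        ggml_tensor  * state_mask, // F32 [1, n_kv], 0/1 per cell
        int32_t n_state,
        int32_t kv_size) {
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_size);
    states = ggml_get_rows(ctx0, states, state_copy); // step 2
    states = ggml_mul     (ctx0, states, state_mask); // step 3 (broadcast over dim 0)
    return states;
}
```
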
@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
 }
 
 // TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_recurrent::build_mamba_layer(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
 }
 
 
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
     return token_shift;
 }
 
-
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
         ggml_context * ctx0,
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
     );
 }
 
-
-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,