@@ -1700,8 +1700,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_KQ_mask_swa_cnv = nullptr;
     inp_KQ_mask_cross   = nullptr;
     inp_k_shift         = nullptr;
-    inp_s_copy          = nullptr;
-    inp_s_mask          = nullptr;
     inp_embd_enc        = nullptr;
     inp_pos_bucket      = nullptr;

@@ -2381,53 +2379,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }

-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-            float * data = (float *) inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // TODO: do not mutate the KV cache
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-            int32_t * data = (int32_t *) inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
     if (inp_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;

@@ -2614,7 +2565,7 @@ void llama_context_kv_self::build_attn_inp(

 void llama_context_kv_self::build_attn_kv_store(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         int32_t n_tokens,
@@ -2635,7 +2586,7 @@ void llama_context_kv_self::build_attn_kv_store(
     // cb(k_cache_view, "k_cache_view", il);

     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -2653,12 +2604,12 @@ void llama_context_kv_self::build_attn_kv_store(
     }
     // cb(v_cache_view, "v_cache_view", il);

-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 }

 ggml_tensor * llama_context_kv_self::build_attn_qkv(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -2791,7 +2742,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         }
     }

-    ggml_build_forward_expand(graph, cur);
+    ggml_build_forward_expand(gf, cur);

     if (wo) {
         cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3103,70 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
     return inp_KQ_mask_cross;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_rwkv
+//
+
+ggml_cgraph * llama_context_rwkv::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context_kv_self::graph_init();
+}
+
+void llama_context_rwkv::input_set(const llama_ubatch & ubatch) {
+    // call base functionality
+    llama_context_kv_self::input_set(ubatch);
+
+    GGML_ASSERT(kv_self.recurrent);
+
+    const int64_t n_kv = kv_self.n;
+
+    if (inp_s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+        float * data = (float *) inp_s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // TODO: do not mutate the KV cache
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+
+    if (inp_s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+        int32_t * data = (int32_t *) inp_s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+ggml_tensor * llama_context_rwkv::build_inp_s_copy(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
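
The two overrides above establish the pattern for the new class: graph_init() resets the recurrent-state input tensors and then delegates to the base implementation, while input_set() first lets the base class fill the regular inputs and then populates inp_s_mask and inp_s_copy from the recurrent KV cells. A hedged sketch of the class declaration this implies follows; the header is not part of this diff, so the member list and qualifiers are assumptions.

    // sketch only: the real declaration lives in a header that is not shown here
    class llama_context_rwkv : public llama_context_kv_self {
    public:
        ggml_cgraph * graph_init() override;
        void input_set(const llama_ubatch & ubatch) override;

        // graph-building helpers moved from llama_context_kv_self (see the hunks below)
        ggml_tensor * build_inp_s_copy(ggml_context * ctx0, bool worst_case);
        ggml_tensor * build_inp_s_mask(ggml_context * ctx0, bool worst_case);

        // recurrent-state inputs, filled on the host in input_set()
        ggml_tensor * inp_s_copy = nullptr; // I32 [n_kv], source cell per state slot
        ggml_tensor * inp_s_mask = nullptr; // F32, zero for cleared states
    };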
@@ -3163,7 +3177,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
     return inp_s_copy;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_rwkv::build_inp_s_mask(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3187,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
     return inp_s_mask;
 }

-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_rwkv::build_copy_mask_state(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * s,
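
Only the n_kv computation and the return statements of build_inp_s_copy and build_inp_s_mask are visible in these hunks. A body consistent with those context lines would allocate the tensor and mark it as a graph input, roughly as sketched below for build_inp_s_copy; the exact calls are assumed, not shown in this commit.

    // plausible body for build_inp_s_copy, consistent with the visible context lines
    const auto n_kv = worst_case ? kv_self.size : kv_self.n;

    // one I32 source-cell index per state slot; the values are written on the host in input_set()
    inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
    ggml_set_input(inp_s_copy);

    return inp_s_copy;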
@@ -3208,7 +3222,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
 }

 // TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_rwkv::build_mamba_layer(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
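
build_copy_mask_state is only renamed here and its body is not part of the diff. Given how input_set() fills inp_s_copy and inp_s_mask, a helper like this typically gathers the source states and masks out cleared ones. The sketch below assumes s is the state tensor shown in the signature and that n_state and kv_size are additional parameters; the concrete operations are an assumption, not taken from this commit.

    // view all cached states as a matrix with one state vector per cell
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_size);

    // state_copy (I32 [n_kv]) selects, for each of the n_kv slots, the cell to read from
    states = ggml_get_rows(ctx0, states, state_copy);

    // state_mask is 0 for slots whose cell had src < 0, clearing the states of
    // sequences that start in this batch (matches how inp_s_mask is filled above)
    states = ggml_mul(ctx0, states, state_mask);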
@@ -3344,7 +3358,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_load(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
@@ -3371,7 +3385,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
+ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_store(
         ggml_context * ctx0,
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
@@ -3395,7 +3409,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_rwkv::build_rwkv6_time_mix(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,