@@ -20,6 +20,8 @@ llama_context::llama_context(
     model     (model),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
+    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+
     const auto & hparams = model.hparams;

     cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
         const llama_context_params & params) :
     llama_context(model, params),
     kv_self(model.hparams) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
+
     const auto & hparams = model.hparams;

     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_KQ_mask_swa_cnv = nullptr;
     inp_KQ_mask_cross   = nullptr;
     inp_k_shift         = nullptr;
-    inp_s_copy          = nullptr;
-    inp_s_mask          = nullptr;
     inp_embd_enc        = nullptr;
     inp_pos_bucket      = nullptr;

@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }

-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-            float * data = (float *) inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // TODO: do not mutate the KV cache
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-            int32_t * data = (int32_t *) inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
     if (inp_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;

@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(

 void llama_context_kv_self::build_attn_kv_store(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         int32_t n_tokens,
@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
     // cb(k_cache_view, "k_cache_view", il);

     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store(
     }
     // cb(v_cache_view, "v_cache_view", il);

-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 }

 ggml_tensor * llama_context_kv_self::build_attn_qkv(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         }
     }

-    ggml_build_forward_expand(graph, cur);
+    ggml_build_forward_expand(gf, cur);

     if (wo) {
         cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
     return inp_KQ_mask_cross;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_recurrent
+//
+
+llama_context_recurrent::llama_context_recurrent(
+        const llama_model & model,
+        const llama_context_params & params) :
+    llama_context_kv_self(model, params) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+}
+
+llama_context_recurrent::~llama_context_recurrent() = default;
+
+ggml_cgraph * llama_context_recurrent::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context_kv_self::graph_init();
+}
+
+void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
+    // call base functionality
+    llama_context_kv_self::input_set(ubatch);
+
+    GGML_ASSERT(kv_self.recurrent);
+
+    const int64_t n_kv = kv_self.n;
+
+    if (inp_s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+        float * data = (float *) inp_s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // TODO: do not mutate the KV cache
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+
+    if (inp_s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+        int32_t * data = (int32_t *) inp_s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+ggml_tensor * llama_context_recurrent::build_inp_s_copy(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
     return inp_s_copy;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_recurrent::build_inp_s_mask(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
     return inp_s_mask;
 }

-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_recurrent::build_copy_mask_state(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * s,
@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
 }

 // TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_recurrent::build_mamba_layer(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
     return token_shift;
 }

-
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
         ggml_context * ctx0,
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
     );
 }

-
-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
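
Note on the pattern used across these hunks: the new llama_context_recurrent subclass defers to llama_context_kv_self for the shared work and only adds the recurrent-state handling on top (graph_init() resets inp_s_copy/inp_s_mask before calling the base version, and input_set() calls the base version before filling them). Below is a small, self-contained C++ sketch of that override-and-defer pattern; it uses placeholder types and print statements rather than the real llama.cpp classes, purely to illustrate the control flow.

// Standalone illustration of the override-and-defer pattern in the diff above.
// The types here are placeholders, not the real llama.cpp classes.
#include <cstdio>

struct ubatch {
    int n_tokens;
};

struct context_kv_self {
    virtual ~context_kv_self() = default;

    // base class: prepare the inputs shared by all attention-based contexts
    virtual void input_set(const ubatch & ub) {
        std::printf("kv_self: set common inputs for %d tokens\n", ub.n_tokens);
    }
};

struct context_recurrent : context_kv_self {
    // derived class: run the base logic first, then fill the recurrent-state
    // inputs (the analogue of inp_s_mask / inp_s_copy in the diff)
    void input_set(const ubatch & ub) override {
        context_kv_self::input_set(ub); // call base functionality
        std::printf("recurrent: set state mask/copy inputs\n");
    }
};

int main() {
    const ubatch ub{8};
    context_recurrent ctx;
    ctx.input_set(ub); // prints the common step, then the recurrent step
    return 0;
}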