Commit e17e4b7

context : add llama_context_recurrent
ggml-ci
1 parent 5f11a55 commit e17e4b7

File tree

5 files changed: +266 -83 lines
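In outline, this commit moves the recurrent-state machinery (the s_copy/s_mask inputs and the Mamba/RWKV graph builders) out of llama_context_kv_self into a new subclass. A rough sketch of the resulting hierarchy, abbreviated to the members visible in the diffs below:

    // Sketch only: simplified from the diffs below, most members omitted.
    class llama_context { /* cparams, model ref, timing */ };

    class llama_context_kv_self : public llama_context {
        // attention KV cache; no longer owns the recurrent inputs
    };

    class llama_context_recurrent : public llama_context_kv_self {
        // recurrent architectures (RWKV, Mamba)
    public:
        ggml_cgraph * graph_init() override;   // resets inp_s_copy/inp_s_mask
        void input_set(const llama_ubatch & ubatch) override;
    protected:
        ggml_tensor * inp_s_copy; // I32 [kv_size]
        ggml_tensor * inp_s_mask; // F32 [1, n_kv]
    };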


src/llama-context.cpp

Lines changed: 88 additions & 63 deletions
@@ -20,6 +20,8 @@ llama_context::llama_context(
     model     (model),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
+    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
         const llama_context_params & params) :
     llama_context(model, params),
     kv_self(model.hparams) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
+
     const auto & hparams = model.hparams;
 
     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_KQ_mask_swa_cnv = nullptr;
     inp_KQ_mask_cross = nullptr;
     inp_k_shift = nullptr;
-    inp_s_copy = nullptr;
-    inp_s_mask = nullptr;
     inp_embd_enc = nullptr;
     inp_pos_bucket = nullptr;
 
@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }
 
-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-            float * data = (float *) inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // TODO: do not mutate the KV cache
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-            int32_t * data = (int32_t *) inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
     if (inp_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;
 
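The 47 deleted lines above are not dropped: they reappear essentially verbatim further down in this file as the body of llama_context_recurrent::input_set, which first delegates to this base implementation and then fills the recurrent-only inputs. In sketch form:

    // Sketch of the relocation (the full body appears in the +3107 hunk below).
    void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
        llama_context_kv_self::input_set(ubatch); // base attention inputs first
        // ... then populate inp_s_mask / inp_s_copy (the block removed above)
    }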
@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(
 
 void llama_context_kv_self::build_attn_kv_store(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         int32_t n_tokens,
@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
     //cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -2653,12 +2608,12 @@
     }
     //cb(v_cache_view, "v_cache_view", il);
 
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 }
 
 ggml_tensor * llama_context_kv_self::build_attn_qkv(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         }
     }
 
-    ggml_build_forward_expand(graph, cur);
+    ggml_build_forward_expand(gf, cur);
 
     if (wo) {
         cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
     return inp_KQ_mask_cross;
 }
 
-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_recurrent
+//
+
+llama_context_recurrent::llama_context_recurrent(
+        const llama_model & model,
+        const llama_context_params & params) :
+    llama_context_kv_self(model, params) {
+    LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+}
+
+llama_context_recurrent::~llama_context_recurrent() = default;
+
+ggml_cgraph * llama_context_recurrent::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context_kv_self::graph_init();
+}
+
+void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
+    // call base functionality
+    llama_context_kv_self::input_set(ubatch);
+
+    GGML_ASSERT(kv_self.recurrent);
+
+    const int64_t n_kv = kv_self.n;
+
+    if (inp_s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+        float * data = (float *) inp_s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // TODO: do not mutate the KV cache
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+
+    if (inp_s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+        int32_t * data = (int32_t *) inp_s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+ggml_tensor * llama_context_recurrent::build_inp_s_copy(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
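The relocated input_set logic deserves a word: inp_s_mask marks which recurrent states are valid (0.0f means the state must be cleared before use), and inp_s_copy records, for each cell in the [head, head + n) window, which cell its state should be gathered from; both loops then self-assign kv_cell.src so the clear/copy happens at most once. A standalone toy run of the same bookkeeping, assuming a 4-cell cache with head = 0 and llama_kv_cell reduced to its src field:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Toy model of the cells touched by input_set (src field only).
    struct toy_cell { int32_t src; };

    int main() {
        // cell 0: fresh (src < 0); cells 1-3: states sourced from cells 2, 1, 3
        std::vector<toy_cell> cells = { {-1}, {2}, {1}, {3} };
        const uint32_t head = 0, n_kv = 4, size = 4;

        std::vector<float>   s_mask(n_kv);
        std::vector<int32_t> s_copy(n_kv);

        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t cell_id = i + head;
            toy_cell & c = cells[cell_id];

            s_mask[i] = (float) (c.src >= 0);        // 0.0f => state is cleared
            if (c.src < 0 || (uint32_t) c.src >= size) {
                c.src = cell_id;                     // clamp invalid sources
            }
            s_copy[i] = c.src;                       // gather source for cell i
            c.src = cell_id;                         // make the copy one-shot
        }

        for (uint32_t i = 0; i < n_kv; ++i) {
            printf("cell %u: mask=%.0f copy_from=%d\n", (unsigned) i, s_mask[i], s_copy[i]);
        }
        // prints: cell 0: mask=0 copy_from=0   (cleared, self-sourced)
        //         cell 1: mask=1 copy_from=2   (state gathered from cell 2)
        //         cell 2: mask=1 copy_from=1
        //         cell 3: mask=1 copy_from=3   (already in place)
        return 0;
    }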
@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
     return inp_s_copy;
 }
 
-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_recurrent::build_inp_s_mask(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
     return inp_s_mask;
 }
 
-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_recurrent::build_copy_mask_state(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * s,
@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
 }
 
 // TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_recurrent::build_mamba_layer(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
 }
 
 
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
ggml_tensor * state_copy,
@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
33703397
return token_shift;
33713398
}
33723399

3373-
3374-
ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
3400+
ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
33753401
ggml_context * ctx0,
33763402
ggml_tensor * token_shift,
33773403
const llama_ubatch & ubatch,
@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
     );
 }
 
-
-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
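One side effect of the new LLAMA_LOG_INFO calls: instantiating the most derived context now logs once per level, base first, since C++ runs base-class constructors before derived ones. The output should look roughly like the following (the exact prefix depends on how __func__ renders in each constructor and on the configured log format):

    constructing llama_context
    constructing llama_context_kv_self
    constructing llama_context_recurrent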

src/llama-context.h

Lines changed: 22 additions & 10 deletions
@@ -433,15 +433,28 @@ class llama_context_kv_self : public llama_context {
             int32_t n_tokens,
             bool worst_case) override;
 
-    // === recurrent ===
+protected:
+    virtual size_t state_get_data(llama_io_write_i & io) override;
+    virtual size_t state_set_data(llama_io_read_i & io) override;
 
-    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+};
 
-    // TODO: add recurrent cache
-    // TODO: add mamba-specific llama_context
+// a recurrent transformer (i.e. RWKV, Mamba)
+// TODO: temporarily reuses kv_self; in the future, implement a recurrent-specific context with its own cache
+class llama_context_recurrent : public llama_context_kv_self {
+public:
+    llama_context_recurrent(
+            const llama_model & model,
+            const llama_context_params & params);
+
+    virtual ~llama_context_recurrent();
+
+    virtual ggml_cgraph * graph_init() override;
+
+    virtual void input_set(const llama_ubatch & ubatch) override;
 
-    // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
     virtual ggml_tensor * build_inp_s_copy(
             ggml_context * ctx0,
             bool worst_case) override;
@@ -499,11 +512,10 @@ class llama_context_kv_self : public llama_context {
             bool worst_case) override;
 
 protected:
-    virtual size_t state_get_data(llama_io_write_i & io) override;
-    virtual size_t state_set_data(llama_io_read_i & io) override;
+    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
 
-    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+    // TODO: add recurrent cache
 };
 
 // For internal test use
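With the header split in place, a construction site would choose the context type per model architecture. A hypothetical selection sketch — not part of this commit; the use of llama_model_is_recurrent as the dispatch predicate is an assumption here:

    // Hypothetical factory (not in this commit): pick the context class
    // based on whether the model uses a recurrent architecture.
    static llama_context * make_context(
            const llama_model & model,
            const llama_context_params & params) {
        if (llama_model_is_recurrent(&model)) {
            return new llama_context_recurrent(model, params); // Mamba, RWKV
        }
        return new llama_context_kv_self(model, params);       // attention KV
    }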
