
Commit c724a3d

context : add llama_context_rwkv

ggml-ci

1 parent 5f11a55 commit c724a3d

File tree

4 files changed: +210 -77 lines changed

src/llama-context.cpp

Lines changed: 75 additions & 61 deletions
@@ -1700,8 +1700,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_KQ_mask_swa_cnv = nullptr;
     inp_KQ_mask_cross   = nullptr;
     inp_k_shift         = nullptr;
-    inp_s_copy          = nullptr;
-    inp_s_mask          = nullptr;
     inp_embd_enc        = nullptr;
     inp_pos_bucket      = nullptr;

@@ -2381,53 +2379,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }

-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-            float * data = (float *) inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // TODO: do not mutate the KV cache
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-            int32_t * data = (int32_t *) inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // TODO: do not mutate the KV cache
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
     if (inp_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;
@@ -2614,7 +2565,7 @@ void llama_context_kv_self::build_attn_inp(

 void llama_context_kv_self::build_attn_kv_store(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         int32_t n_tokens,
@@ -2635,7 +2586,7 @@ void llama_context_kv_self::build_attn_kv_store(
     //cb(k_cache_view, "k_cache_view", il);

     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -2653,12 +2604,12 @@ void llama_context_kv_self::build_attn_kv_store(
     }
     //cb(v_cache_view, "v_cache_view", il);

-    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
 }

 ggml_tensor * llama_context_kv_self::build_attn_qkv(
         ggml_context * ctx0,
-        ggml_cgraph * graph,
+        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -2791,7 +2742,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
         }
     }

-    ggml_build_forward_expand(graph, cur);
+    ggml_build_forward_expand(gf, cur);

     if (wo) {
         cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3103,70 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
     return inp_KQ_mask_cross;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_rwkv
+//
+
+ggml_cgraph * llama_context_rwkv::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context_kv_self::graph_init();
+}
+
+void llama_context_rwkv::input_set(const llama_ubatch & ubatch) {
+    // call base functionality
+    llama_context_kv_self::input_set(ubatch);
+
+    GGML_ASSERT(kv_self.recurrent);
+
+    const int64_t n_kv = kv_self.n;
+
+    if (inp_s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+        float * data = (float *) inp_s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // TODO: do not mutate the KV cache
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+
+    if (inp_s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+        int32_t * data = (int32_t *) inp_s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t cell_id = i + kv_self.head;
+            llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+ggml_tensor * llama_context_rwkv::build_inp_s_copy(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3163,7 +3177,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
     return inp_s_copy;
 }

-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_rwkv::build_inp_s_mask(
         ggml_context * ctx0,
         bool worst_case) {
     const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3187,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
     return inp_s_mask;
 }

-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_rwkv::build_copy_mask_state(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * s,
@@ -3208,7 +3222,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
 }

 // TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_rwkv::build_mamba_layer(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
@@ -3344,7 +3358,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_load(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
@@ -3371,7 +3385,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
+ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_store(
         ggml_context * ctx0,
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
@@ -3395,7 +3409,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
 }


-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_rwkv::build_rwkv6_time_mix(
         ggml_context * ctx0,
         ggml_cgraph * gf,
         ggml_tensor * cur,
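
The block that moves here is the recurrent-state bookkeeping that llama_context_rwkv::input_set now performs per micro-batch: inp_s_mask gets 1.0f for each cell in the [head, head + n) window that already holds a valid state (0.0f tells the graph to clear the slot before use), and inp_s_copy gets, for each slot, the index of the cell it should copy its state from. A minimal standalone sketch of the same bookkeeping, condensed into a single pass over toy types (cell, cells, head, n are illustrative stand-ins, not the llama.cpp structures):

    // Toy model of the recurrent-state bookkeeping above. Standalone sketch,
    // not the llama.cpp types: cell::src < 0 means "no valid state yet".
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct cell { int32_t src = -1; };

    int main() {
        std::vector<cell> cells(8);
        const uint32_t head = 2, n = 4;

        // seq 0's state lives in cell 3; cell 4 should copy its state from cell 3
        cells[3].src = 3;
        cells[4].src = 3;

        std::vector<float>   s_mask(n);
        std::vector<int32_t> s_copy(n);

        for (uint32_t i = 0; i < n; ++i) {
            const uint32_t cell_id = i + head;
            cell & c = cells[cell_id];

            // mask: 1.0f keeps the existing state, 0.0f clears the slot
            s_mask[i] = (float) (c.src >= 0);
            if (c.src < 0) {
                c.src = cell_id; // only clear once
            }

            // copy: which cell each slot reads its state from
            if (c.src < 0 || (uint32_t) c.src >= cells.size()) {
                c.src = cell_id; // prevent out-of-bound sources
            }
            s_copy[i] = c.src;
            c.src = cell_id;     // ensure copy only happens once
        }

        for (uint32_t i = 0; i < n; ++i) {
            printf("slot %u: mask=%.0f copy_from=%d\n", i, s_mask[i], s_copy[i]);
        }
    }

Running this prints mask=0 for the two untouched cells (2 and 5, which copy from themselves) and mask=1 for cells 3 and 4, both reading their state from cell 3, which is exactly the shape of data the graph's state-copy/state-mask ops consume.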

src/llama-context.h

Lines changed: 14 additions & 9 deletions
@@ -433,13 +433,25 @@ class llama_context_kv_self : public llama_context {
             int32_t n_tokens,
             bool worst_case) override;

-    // === recurrent ===
+protected:
+    virtual size_t state_get_data(llama_io_write_i & io) override;
+    virtual size_t state_set_data(llama_io_read_i & io) override;
+
+    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+};
+
+// TODO: temporary reuse kv_self, but in the future, implement specific context
+class llama_context_rwkv : public llama_context_kv_self {
+public:
+    virtual ggml_cgraph * graph_init() override;
+
+    virtual void input_set(const llama_ubatch & ubatch) override;

     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]

     // TODO: add recurrent cache
-    // TODO: add mamba-specific llama_context

     // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
     virtual ggml_tensor * build_inp_s_copy(
@@ -497,13 +509,6 @@ class llama_context_kv_self : public llama_context {
             const llama_ubatch & ubatch,
             int il,
             bool worst_case) override;
-
-protected:
-    virtual size_t state_get_data(llama_io_write_i & io) override;
-    virtual size_t state_set_data(llama_io_read_i & io) override;
-
-    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
 };

 // For internal test use
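
The net effect of the header change is a small class hierarchy: the state-I/O overrides stay in llama_context_kv_self, while the recurrent inputs and builders move into the new llama_context_rwkv subclass, which overrides graph_init() and input_set() and chains to the base implementations. A reduced sketch of that override-and-chain pattern, with illustrative toy classes rather than the real llama.cpp signatures:

    // Minimal sketch of the override-and-chain pattern used by llama_context_rwkv.
    // The class and member names here are illustrative stand-ins, not llama.cpp API.
    #include <cstdio>

    struct context_base {
        virtual ~context_base() = default;

        virtual void graph_init() {
            printf("base: reset attention inputs\n");
        }
        virtual void input_set() {
            printf("base: fill attention inputs\n");
        }
    };

    struct context_rwkv : context_base {
        void graph_init() override {
            // subclass resets its own inputs, then chains to the base
            printf("rwkv: reset s_copy/s_mask\n");
            context_base::graph_init();
        }
        void input_set() override {
            // subclass chains first, then fills its recurrent inputs
            context_base::input_set();
            printf("rwkv: fill s_copy/s_mask\n");
        }
    };

    int main() {
        context_rwkv ctx;
        ctx.graph_init(); // rwkv resets, then base resets
        ctx.input_set();  // base fills, then rwkv fills
    }

This mirrors the commit: graph_init() clears inp_s_copy/inp_s_mask before delegating, and input_set() lets the base fill the shared inputs before writing the recurrent ones.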

src/llama-graph.cpp

Lines changed: 114 additions & 0 deletions
@@ -1 +1,115 @@
 #include "llama-graph.h"
+
+#include "ggml.h"
+
+ggml_tensor * llama_graph_i::build_inp_s_copy(
+        ggml_context * ctx0,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_inp_s_mask(
+        ggml_context * ctx0,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_copy_mask_state(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        int32_t n_tokens,
+        int32_t n_state,
+        int32_t n_seqs,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(s);
+    GGML_UNUSED(state_copy);
+    GGML_UNUSED(state_mask);
+    GGML_UNUSED(n_tokens);
+    GGML_UNUSED(n_state);
+    GGML_UNUSED(n_seqs);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_mamba_layer(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * cur,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(cur);
+    GGML_UNUSED(state_copy);
+    GGML_UNUSED(state_mask);
+    GGML_UNUSED(ubatch);
+    GGML_UNUSED(il);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_rwkv_token_shift_load(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(state_copy);
+    GGML_UNUSED(state_mask);
+    GGML_UNUSED(ubatch);
+    GGML_UNUSED(il);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_rwkv_token_shift_store(
+        ggml_context * ctx0,
+        ggml_tensor * token_shift,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(token_shift);
+    GGML_UNUSED(ubatch);
+    GGML_UNUSED(il);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
+
+ggml_tensor * llama_graph_i::build_rwkv6_time_mix(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * cur,
+        ggml_tensor * x_prev,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    GGML_UNUSED(ctx0);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(cur);
+    GGML_UNUSED(x_prev);
+    GGML_UNUSED(state_copy);
+    GGML_UNUSED(state_mask);
+    GGML_UNUSED(ubatch);
+    GGML_UNUSED(il);
+    GGML_UNUSED(worst_case);
+    GGML_ABORT("not implemented");
+}
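
All of llama-graph.cpp's new definitions follow one pattern: the llama_graph_i interface gains default bodies that mark every parameter GGML_UNUSED and then fail via GGML_ABORT("not implemented"), so only graph builders that actually support the recurrent/RWKV ops (here, llama_context_rwkv) must override them, and a stray call on a non-recurrent context fails loudly instead of silently. A self-contained sketch of that abort-by-default pattern; UNUSED/ABORT below are local stand-ins for the ggml macros:

    // Sketch of the "abort-by-default virtual" pattern from llama-graph.cpp.
    #include <cstdio>
    #include <cstdlib>

    #define UNUSED(x) (void)(x)
    #define ABORT(msg) do { fprintf(stderr, "%s\n", msg); abort(); } while (0)

    struct graph_i {
        virtual ~graph_i() = default;

        // default body: not every graph builder supports recurrent states
        virtual int build_inp_s_copy(int n_kv) {
            UNUSED(n_kv);
            ABORT("not implemented");
        }
    };

    struct graph_rwkv : graph_i {
        int build_inp_s_copy(int n_kv) override {
            return n_kv; // a real override builds the I32 [kv_size] tensor instead
        }
    };

    int main() {
        graph_rwkv g;
        printf("%d\n", g.build_inp_s_copy(4)); // fine: the RWKV builder overrides it

        graph_i base;
        UNUSED(base);
        // base.build_inp_s_copy(4); // would abort with "not implemented"
    }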
