Commit 2d597b3

ggerganov authored and Minh141120 committed
memory : rename interface to llama_memory_context_i (ggml-org#14296)
* memory : rename interface to llama_memory_context_i

  ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

  ggml-ci
1 parent e5be49f commit 2d597b3

13 files changed with 263 additions and 173 deletions.
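Orientation before the per-file diffs: the commit uniformly retires the "*_state" naming in favor of "*_context". The mapping below is reconstructed from the hunks on this page; the old interface name llama_memory_state_i is inferred from the renamed classes and variables, so treat it as an assumption rather than a quote from the tree.

// old ("state") name                      new ("context") name, this commit
// llama_memory_state_i (assumed)      ->  llama_memory_context_i
// llama_kv_cache_unified_state        ->  llama_kv_cache_unified_context
// llama_kv_cache_unified_iswa_state   ->  llama_kv_cache_unified_iswa_context
// llama_memory_recurrent_state        ->  llama_memory_recurrent_context
// llama_memory_hybrid_state           ->  llama_memory_hybrid_context
//
// Variables follow suit: mstate / mem_state / kv_state become mctx, and
// graph-building code uses mctx_cur for a pointer narrowed to one cache type.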

src/llama-context.cpp

Lines changed: 4 additions & 4 deletions
@@ -937,8 +937,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }

@@ -2043,8 +2043,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;
        }
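Both hunks above share one shape after the rename: ask the memory module for a batch-scoped context, then validate it before touching the cache. Below is a minimal compilable sketch of that shape with stub types; only init_batch, get_status, and LLAMA_MEMORY_STATUS_SUCCESS are taken from the diff, everything else is invented for illustration.

#include <memory>

// Stub status enum; the real one lives in llama-memory.h and has more values.
enum llama_memory_status {
    LLAMA_MEMORY_STATUS_SUCCESS,
    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
};

// Stub interface mirroring the renamed llama_memory_context_i.
struct memory_context {
    virtual ~memory_context() = default;
    virtual llama_memory_status get_status() const = 0;
};

using memory_context_ptr = std::unique_ptr<memory_context>;

struct ok_context : memory_context {
    llama_memory_status get_status() const override { return LLAMA_MEMORY_STATUS_SUCCESS; }
};

// Stands in for memory->init_batch(*balloc, cparams.n_ubatch, output_all).
struct memory_module {
    memory_context_ptr init_batch(unsigned /*n_ubatch*/, bool /*output_all*/) {
        return std::make_unique<ok_context>();
    }
};

int decode_sketch(memory_module & memory) {
    auto mctx = memory.init_batch(/*n_ubatch=*/512, /*output_all=*/false);
    // decode() returns -2 on a null context; opt_epoch_iter() additionally
    // rejects a context whose status is not SUCCESS and breaks out instead.
    if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return -2;
    }
    return 0;
}

Keeping the validity check at each call site lets decode() and opt_epoch_iter() fail in their own way while sharing the same init_batch entry point.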

src/llama-graph.cpp

Lines changed: 113 additions & 34 deletions
@@ -221,15 +221,15 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mem_state->get_n_rs();
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->s_copy(i);
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -338,18 +338,18 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recr()->s_copy(i);
+            data[i] = mctx->get_recr()->s_copy(i);
         }
     }
 }
@@ -999,14 +999,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1016,7 +1016,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1297,12 +1297,12 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1384,18 +1384,97 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+    {
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
 
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
+        bool avoid_copies) const {
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
@@ -1426,11 +1505,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = kv_state->get_n_rs();
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ -1445,9 +1524,9 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1457,23 +1536,23 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
         inp, gf, token_shift_all,
@@ -1488,7 +1567,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1500,7 +1579,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
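Beyond the rename, this file also replaces build_recurrent_state with a family of build_rs overloads: one core overload that takes the cache geometry (n_kv, kv_head, kv_size, rs_zero) as explicit arguments, plus thin wrappers that read those values off a concrete memory context. The sketch below shows that layering with stand-in types; it is an illustration of the pattern, not the real implementation, and the core body (the state gather/copy) is elided.

#include <cstdint>

// Stand-in for a recurrent memory context; the getter names match the diff.
struct recurrent_context {
    uint32_t get_n_rs() const { return 4; }
    uint32_t get_head() const { return 0; }
    uint32_t get_size() const { return 16; }
    int32_t  get_rs_z() const { return -1; }
};

struct tensor; // stand-in for ggml_tensor

// Core overload: geometry is explicit, so it depends on no particular
// memory-context type (the real code builds ggml graph nodes here).
tensor * build_rs_core(tensor * s, tensor * state_copy,
                       int32_t state_size, int32_t n_seqs,
                       uint32_t n_kv, uint32_t kv_head, uint32_t kv_size,
                       int32_t rs_zero, bool avoid_copies) {
    (void) state_copy; (void) state_size; (void) n_seqs; (void) n_kv;
    (void) kv_head; (void) kv_size; (void) rs_zero; (void) avoid_copies;
    return s; // elided: reshape s to [state_size, kv_size] and gather states
}

// Thin wrapper: pulls the geometry from the context and forwards, exactly
// the shape of the two short build_rs overloads in the hunks above.
tensor * build_rs(const recurrent_context * mctx_cur, tensor * s, tensor * state_copy,
                  int32_t state_size, int32_t n_seqs, bool avoid_copies = false) {
    return build_rs_core(s, state_copy, state_size, n_seqs,
                         mctx_cur->get_n_rs(), mctx_cur->get_head(),
                         mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}

In the diff, the hybrid overload differs only in how it obtains the context: static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr() instead of casting mctx directly.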

src/llama-graph.h

Lines changed: 9 additions & 9 deletions
@@ -19,10 +19,10 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -193,14 +193,14 @@ class llm_graph_input_cls : public llm_graph_input_i {
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * mem_state;
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -308,10 +308,10 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * mem_state) :
+            const llama_memory_hybrid_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        mem_state(mem_state) {
+        mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;

@@ -327,7 +327,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * mem_state;
+    const llama_memory_hybrid_context * mctx;
 };
 
 //
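The header-side pattern is uniform: each graph-input class borrows a const pointer to its (renamed) memory context at construction and dereferences it later in set_input. A minimal sketch of that borrowing shape, with stub types; the real llm_graph_input_rs also owns the s_copy tensor and more.

#include <cstdint>
#include <vector>

struct llama_ubatch {}; // stub

// Stub recurrent context exposing the two getters set_input needs.
struct recurrent_context {
    int64_t get_n_rs() const { return (int64_t) copies.size(); }
    int32_t s_copy(int64_t i) const { return copies[(size_t) i]; }
    std::vector<int32_t> copies{0, 1, 2};
};

class input_rs_sketch {
public:
    explicit input_rs_sketch(const recurrent_context * mctx) : mctx(mctx) {}

    // Mirrors llm_graph_input_rs::set_input: read the copy destinations out
    // of the borrowed context (here into a plain vector instead of a tensor).
    void set_input(const llama_ubatch * /*ubatch*/) {
        data.clear();
        for (int64_t i = 0; i < mctx->get_n_rs(); ++i) {
            data.push_back(mctx->s_copy(i));
        }
    }

    std::vector<int32_t> data;
    const recurrent_context * mctx; // borrowed, not owned
};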
