Skip to content

Commit cfa0980

Browse files
committed
llama : cont
ggml-ci
1 parent 4340e63 commit cfa0980

File tree

19 files changed

+126
-78
lines changed

19 files changed

+126
-78
lines changed

examples/batched-bench/batched-bench.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ int main(int argc, char ** argv) {
5757
return 1;
5858
}
5959

60+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
61+
6062
const int32_t n_kv_max = llama_n_ctx(ctx);
6163

6264
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@@ -132,7 +134,7 @@ int main(int argc, char ** argv) {
132134

133135
const auto t_pp_start = ggml_time_us();
134136

135-
llama_kv_cache_clear(ctx);
137+
llama_kv_cache_clear(kv);
136138

137139
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
138140
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +143,7 @@ int main(int argc, char ** argv) {
141143

142144
if (is_pp_shared) {
143145
for (int32_t i = 1; i < pl; ++i) {
144-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
146+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
145147
}
146148
}
147149

examples/cvector-generator/cvector-generator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
342342
}
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345-
llama_kv_cache_clear(ctx);
345+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
346+
llama_kv_cache_clear(kv);
346347
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
347348
fprintf(stderr, "%s : failed to eval\n", __func__);
348349
return false;

examples/gritlm/gritlm.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
1313
const llama_model * model = llama_get_model(ctx);
1414
const llama_vocab * vocab = llama_model_get_vocab(model);
1515

16+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
17+
1618
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
1719

1820
for (uint64_t i = 0; i < sentences.size(); i++) {
@@ -45,7 +47,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4547
}
4648

4749
// clear previous kv_cache values (irrelevant for embeddings)
48-
llama_kv_cache_clear(ctx);
50+
llama_kv_cache_clear(kv);
4951
llama_set_embeddings(ctx, true);
5052
llama_set_causal_attn(ctx, false);
5153

@@ -100,9 +102,11 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
100102
const llama_model * model = llama_get_model(ctx);
101103
const llama_vocab * vocab = llama_model_get_vocab(model);
102104

105+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
106+
103107
llama_token eos_token = llama_vocab_eos(vocab);
104108

105-
llama_kv_cache_clear(ctx);
109+
llama_kv_cache_clear(kv);
106110
llama_set_embeddings(ctx, false);
107111
llama_set_causal_attn(ctx, true);
108112

examples/imatrix/imatrix.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
431431
const llama_model * model = llama_get_model(ctx);
432432
const llama_vocab * vocab = llama_model_get_vocab(model);
433433

434+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
435+
434436
const bool add_bos = llama_vocab_get_add_bos(vocab);
435437
const int n_ctx = llama_n_ctx(ctx);
436438

@@ -497,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
497499
const auto t_start = std::chrono::high_resolution_clock::now();
498500

499501
// clear the KV cache
500-
llama_kv_cache_clear(ctx);
502+
llama_kv_cache_clear(kv);
501503

502504
llama_batch batch = llama_batch_init(n_batch, 0, 1);
503505

examples/infill/infill.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ int main(int argc, char ** argv) {
139139
return 1;
140140
}
141141

142+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
143+
142144
const llama_vocab * vocab = llama_model_get_vocab(model);
143145

144146
const int n_ctx_train = llama_model_n_ctx_train(model);
@@ -332,8 +334,8 @@ int main(int argc, char ** argv) {
332334
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
333335
n_past, n_left, n_ctx, params.n_keep, n_discard);
334336

335-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
337+
llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
338+
llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
337339

338340
n_past -= n_discard;
339341

examples/llama-bench/llama-bench.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,9 +1575,11 @@ int main(int argc, char ** argv) {
15751575
return 1;
15761576
}
15771577

1578+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
1579+
15781580
test t(inst, lmodel, ctx);
15791581

1580-
llama_kv_cache_clear(ctx);
1582+
llama_kv_cache_clear(kv);
15811583

15821584
// cool off before the test
15831585
if (params.delay) {
@@ -1617,7 +1619,7 @@ int main(int argc, char ** argv) {
16171619
}
16181620

16191621
for (int i = 0; i < params.reps; i++) {
1620-
llama_kv_cache_clear(ctx);
1622+
llama_kv_cache_clear(kv);
16211623

16221624
uint64_t t_start = get_time_ns();
16231625

examples/lookahead/lookahead.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ int main(int argc, char ** argv) {
6060

6161
llama_model * model = llama_init.model.get();
6262
llama_context * ctx = llama_init.context.get();
63+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
6364

6465
const llama_vocab * vocab = llama_model_get_vocab(model);
6566

@@ -95,7 +96,7 @@ int main(int argc, char ** argv) {
9596
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
9697

9798
for (int s = 1; s < W + G + 1; ++s) {
98-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
99+
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
99100
}
100101

101102
const auto t_enc_end = ggml_time_us();
@@ -437,17 +438,17 @@ int main(int argc, char ** argv) {
437438

438439
// KV cache management
439440
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
440-
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
441+
llama_kv_cache_seq_rm(kv, -1, n_past, -1);
441442

442443
if (seq_id_best != 0) {
443444
// if a verification token matched, we keep the best sequence and remove the rest
444445
// this leads to some KV cache fragmentation
445-
llama_kv_cache_seq_keep(ctx, seq_id_best);
446-
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
447-
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
446+
llama_kv_cache_seq_keep(kv, seq_id_best);
447+
llama_kv_cache_seq_cp (kv, seq_id_best, 0, -1, -1);
448+
llama_kv_cache_seq_rm (kv, seq_id_best, -1, -1);
448449

449450
for (int s = 1; s < W + G + 1; ++s) {
450-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
451+
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
451452
}
452453
}
453454
}

examples/lookup/lookup.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ int main(int argc, char ** argv){
3535

3636
llama_model * model = llama_init.model.get();
3737
llama_context * ctx = llama_init.context.get();
38+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
3839

3940
const llama_vocab * vocab = llama_model_get_vocab(model);
4041

@@ -192,7 +193,7 @@ int main(int argc, char ** argv){
192193

193194
// KV cache management
194195
// clean the cache of draft tokens that weren't accepted
195-
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
196+
llama_kv_cache_seq_rm(kv, 0, n_past, -1);
196197

197198
common_batch_clear(batch_tgt);
198199
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

examples/main/main.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ int main(int argc, char ** argv) {
164164
return 1;
165165
}
166166

167+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
168+
167169
const llama_vocab * vocab = llama_model_get_vocab(model);
168170

169171
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -326,7 +328,7 @@ int main(int argc, char ** argv) {
326328
}
327329

328330
// remove any "future" tokens that we might have inherited from the previous session
329-
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
331+
llama_kv_cache_seq_rm(kv, -1, n_matching_session_tokens, -1);
330332
}
331333

332334
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -567,8 +569,8 @@ int main(int argc, char ** argv) {
567569
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
568570
n_past, n_left, n_ctx, params.n_keep, n_discard);
569571

570-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
571-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
572+
llama_kv_cache_seq_rm (kv, 0, params.n_keep , params.n_keep + n_discard);
573+
llama_kv_cache_seq_add(kv, 0, params.n_keep + n_discard, n_past, -n_discard);
572574

573575
n_past -= n_discard;
574576

@@ -591,9 +593,9 @@ int main(int argc, char ** argv) {
591593
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
592594
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
593595

594-
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
595-
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
596-
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
596+
llama_kv_cache_seq_add(kv, 0, ga_i, n_past, ib*bd);
597+
llama_kv_cache_seq_div(kv, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
598+
llama_kv_cache_seq_add(kv, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
597599

598600
n_past -= bd;
599601

examples/parallel/parallel.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ int main(int argc, char ** argv) {
134134

135135
llama_model * model = llama_init.model.get();
136136
llama_context * ctx = llama_init.context.get();
137+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
137138

138139
const llama_vocab * vocab = llama_model_get_vocab(model);
139140

@@ -201,7 +202,7 @@ int main(int argc, char ** argv) {
201202

202203
// assign the system KV cache to all parallel sequences
203204
for (int32_t i = 1; i <= n_clients; ++i) {
204-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
205+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
205206
}
206207

207208
LOG_INF("\n");
@@ -233,9 +234,9 @@ int main(int argc, char ** argv) {
233234
if (batch.n_tokens == 0) {
234235
// all sequences have ended - clear the entire KV cache
235236
for (int i = 1; i <= n_clients; ++i) {
236-
llama_kv_cache_seq_rm(ctx, i, -1, -1);
237+
llama_kv_cache_seq_rm(kv, i, -1, -1);
237238
// but keep the system prompt
238-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
239+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
239240
}
240241

241242
LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -371,8 +372,8 @@ int main(int argc, char ** argv) {
371372
}
372373

373374
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
374-
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
375-
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
375+
llama_kv_cache_seq_rm(kv, client.id + 1, -1, -1);
376+
llama_kv_cache_seq_cp(kv, 0, client.id + 1, -1, -1);
376377

377378
const auto t_main_end = ggml_time_us();
378379

0 commit comments

Comments
 (0)