Skip to content

Commit 26eb48c

Browse files
committed
talk-llama : sync llama.cpp
ggml-ci
1 parent 546928c commit 26eb48c

18 files changed

+1951
-1161
lines changed

examples/talk-llama/llama-batch.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "llama-batch.h"
22

3+
#include <cassert>
34
#include <cstring>
45
#include <algorithm>
56

@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
281282
batch = in_batch;
282283
GGML_ASSERT(batch.n_tokens > 0);
283284
if (!batch.pos) {
285+
assert(p0 >= 0);
284286
pos.resize(batch.n_tokens);
285287
for (int32_t i = 0; i < batch.n_tokens; i++) {
286-
pos[i] = i + p0;
288+
pos[i] = p0 + i;
287289
}
288290
batch.pos = pos.data();
289291
}

examples/talk-llama/llama-context.cpp

Lines changed: 79 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ llama_context::llama_context(
2525

2626
const auto & hparams = model.hparams;
2727

28-
cparams.n_seq_max = std::max(1u, params.n_seq_max);
28+
cparams.n_seq_max = std::max(1u, params.n_seq_max);
29+
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
30+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
31+
}
32+
2933
cparams.n_threads = params.n_threads;
3034
cparams.n_threads_batch = params.n_threads_batch;
3135
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -93,6 +97,7 @@ llama_context::llama_context(
9397
}
9498

9599
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
100+
96101
cparams.op_offload = params.op_offload;
97102

98103
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +181,9 @@ llama_context::llama_context(
176181
// init the memory module
177182
if (!hparams.vocab_only) {
178183
llama_memory_params params_mem = {
179-
/*.type_k =*/ params.type_k,
180-
/*.type_v =*/ params.type_v,
184+
/*.type_k =*/ params.type_k,
185+
/*.type_v =*/ params.type_v,
186+
/*.swa_full =*/ params.swa_full,
181187
};
182188

183189
memory.reset(model.create_memory(params_mem, cparams));
@@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
687693

688694
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
689695

696+
// TODO: move the validation to the llama_batch_allocr
690697
if (batch.token) {
691698
for (int32_t i = 0; i < n_tokens; ++i) {
692699
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
693700
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
694701
return -1;
695702
}
703+
704+
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
706+
throw -1;
707+
}
696708
}
697709
}
698710

@@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
846858

847859
int llama_context::decode(llama_batch & inp_batch) {
848860
if (!memory) {
849-
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
861+
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
850862
return encode(inp_batch);
851863
}
852864

@@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
855867
return -1;
856868
}
857869

870+
if (!inp_batch.pos) {
871+
if (inp_batch.seq_id) {
872+
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
873+
return -1;
874+
}
875+
}
876+
858877
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
859878

860879
// temporary allocate memory for the input batch if needed
861-
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
862-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
880+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
863881

864882
const llama_batch & batch = batch_allocr.batch;
865883

@@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
875893

876894
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
877895

896+
// TODO: move the validation to the llama_batch_allocr
878897
if (batch.token) {
879898
for (int64_t i = 0; i < n_tokens_all; ++i) {
880899
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
881900
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
882-
throw std::runtime_error("invalid token");
901+
return -1;
902+
}
903+
904+
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905+
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
906+
return -1;
883907
}
884908
}
885909
}
@@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
947971

948972
// find KV slot
949973
if (!kv_self->find_slot(ubatch)) {
950-
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
951-
952974
return 1;
953975
}
954976

@@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() {
20932115
/*.flash_attn =*/ false,
20942116
/*.no_perf =*/ true,
20952117
/*.op_offload =*/ true,
2118+
/*.swa_full =*/ true,
20962119
};
20972120

20982121
return result;
@@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
22872310
return res ? 0 : -1;
22882311
}
22892312

2290-
//
2291-
// kv cache view
2292-
//
2293-
2294-
llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
2295-
const auto * kv = ctx->get_kv_self();
2296-
if (kv == nullptr) {
2297-
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2298-
return {};
2299-
}
2300-
2301-
return llama_kv_cache_view_init(*kv, n_seq_max);
2302-
}
2303-
2304-
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
2305-
const auto * kv = ctx->get_kv_self();
2306-
if (kv == nullptr) {
2307-
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2308-
return;
2309-
}
2310-
2311-
llama_kv_cache_view_update(view, kv);
2312-
}
2313-
23142313
//
23152314
// kv cache
23162315
//
23172316

23182317
// deprecated
2319-
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
2320-
return llama_kv_self_n_tokens(ctx);
2321-
}
2322-
23232318
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
23242319
const auto * kv = ctx->get_kv_self();
23252320
if (!kv) {
23262321
return 0;
23272322
}
23282323

2329-
return kv->get_n_tokens();
2330-
}
2324+
int32_t res = 0;
23312325

2332-
// deprecated
2333-
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
2334-
return llama_kv_self_used_cells(ctx);
2326+
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2327+
const llama_pos p0 = kv->seq_pos_min(s);
2328+
const llama_pos p1 = kv->seq_pos_max(s);
2329+
2330+
if (p0 >= 0) {
2331+
res += (p1 - p0) + 1;
2332+
}
2333+
}
2334+
2335+
return res;
23352336
}
23362337

2338+
// deprecated
2339+
// note: this is the same as above - will be removed anyway, so it's ok
23372340
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
23382341
const auto * kv = ctx->get_kv_self();
23392342
if (!kv) {
23402343
return 0;
23412344
}
23422345

2343-
return kv->get_used_cells();
2344-
}
2346+
int32_t res = 0;
23452347

2346-
// deprecated
2347-
void llama_kv_cache_clear(llama_context * ctx) {
2348-
llama_kv_self_clear(ctx);
2348+
for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
2349+
const llama_pos p0 = kv->seq_pos_min(s);
2350+
const llama_pos p1 = kv->seq_pos_max(s);
2351+
2352+
if (p0 >= 0) {
2353+
res += (p1 - p0) + 1;
2354+
}
2355+
}
2356+
2357+
return res;
23492358
}
23502359

23512360
void llama_kv_self_clear(llama_context * ctx) {
@@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
23572366
kv->clear();
23582367
}
23592368

2360-
// deprecated
2361-
bool llama_kv_cache_seq_rm(
2362-
llama_context * ctx,
2363-
llama_seq_id seq_id,
2364-
llama_pos p0,
2365-
llama_pos p1) {
2366-
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
2367-
}
2368-
23692369
bool llama_kv_self_seq_rm(
23702370
llama_context * ctx,
23712371
llama_seq_id seq_id,
@@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm(
23792379
return kv->seq_rm(seq_id, p0, p1);
23802380
}
23812381

2382-
// deprecated
2383-
void llama_kv_cache_seq_cp(
2384-
llama_context * ctx,
2385-
llama_seq_id seq_id_src,
2386-
llama_seq_id seq_id_dst,
2387-
llama_pos p0,
2388-
llama_pos p1) {
2389-
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2390-
}
2391-
23922382
void llama_kv_self_seq_cp(
23932383
llama_context * ctx,
23942384
llama_seq_id seq_id_src,
@@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp(
24032393
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
24042394
}
24052395

2406-
// deprecated
2407-
void llama_kv_cache_seq_keep(
2408-
llama_context * ctx,
2409-
llama_seq_id seq_id) {
2410-
llama_kv_self_seq_keep(ctx, seq_id);
2411-
}
2412-
24132396
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
24142397
auto * kv = ctx->get_kv_self();
24152398
if (!kv) {
@@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
24192402
kv->seq_keep(seq_id);
24202403
}
24212404

2422-
// deprecated
2423-
void llama_kv_cache_seq_add(
2424-
llama_context * ctx,
2425-
llama_seq_id seq_id,
2426-
llama_pos p0,
2427-
llama_pos p1,
2428-
llama_pos delta) {
2429-
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2430-
}
2431-
24322405
void llama_kv_self_seq_add(
24332406
llama_context * ctx,
24342407
llama_seq_id seq_id,
@@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add(
24432416
kv->seq_add(seq_id, p0, p1, delta);
24442417
}
24452418

2446-
// deprecated
2447-
void llama_kv_cache_seq_div(
2448-
llama_context * ctx,
2449-
llama_seq_id seq_id,
2450-
llama_pos p0,
2451-
llama_pos p1,
2452-
int d) {
2453-
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2454-
}
2455-
24562419
void llama_kv_self_seq_div(
24572420
llama_context * ctx,
24582421
llama_seq_id seq_id,
@@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div(
24672430
kv->seq_div(seq_id, p0, p1, d);
24682431
}
24692432

2470-
// deprecated
2471-
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2472-
return llama_kv_self_seq_pos_max(ctx, seq_id);
2433+
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2434+
const auto * kv = ctx->get_kv_self();
2435+
if (!kv) {
2436+
return -1;
2437+
}
2438+
2439+
return kv->seq_pos_min(seq_id);
24732440
}
24742441

24752442
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
24762443
const auto * kv = ctx->get_kv_self();
24772444
if (!kv) {
2478-
return 0;
2445+
return -1;
24792446
}
24802447

24812448
return kv->seq_pos_max(seq_id);
24822449
}
24832450

2484-
// deprecated
2485-
void llama_kv_cache_defrag(llama_context * ctx) {
2486-
llama_kv_self_defrag(ctx);
2487-
}
2488-
24892451
void llama_kv_self_defrag(llama_context * ctx) {
24902452
auto * kv = ctx->get_kv_self();
24912453
if (!kv) {
@@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
24962458
kv->defrag_sched(-1.0f);
24972459
}
24982460

2499-
// deprecated
2500-
bool llama_kv_cache_can_shift(const llama_context * ctx) {
2501-
return llama_kv_self_can_shift(ctx);
2502-
}
2503-
25042461
bool llama_kv_self_can_shift(const llama_context * ctx) {
25052462
const auto * kv = ctx->get_kv_self();
25062463
if (!kv) {
@@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
25102467
return kv->get_can_shift();
25112468
}
25122469

2513-
// deprecated
2514-
void llama_kv_cache_update(llama_context * ctx) {
2515-
llama_kv_self_update(ctx);
2516-
}
2517-
25182470
// llama state API
25192471

25202472
// deprecated
@@ -2637,7 +2589,21 @@ int32_t llama_encode(
26372589
int32_t llama_decode(
26382590
llama_context * ctx,
26392591
llama_batch batch) {
2640-
const int ret = ctx->decode(batch);
2592+
int ret = ctx->decode(batch);
2593+
2594+
// defrag and try again
2595+
// TODO: distinguish return code when we are sure that even after defrag there is no space available
2596+
if (ret == 1) {
2597+
llama_kv_self_defrag(ctx);
2598+
ret = ctx->decode(batch);
2599+
2600+
if (ret == 1) {
2601+
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
2602+
2603+
return ret;
2604+
}
2605+
}
2606+
26412607
if (ret != 0) {
26422608
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
26432609
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
#include "llama-cparams.h"
2+
3+
size_t llama_max_parallel_sequences(void) {
4+
return LLAMA_MAX_PARALLEL_SEQUENCES;
5+
}

0 commit comments

Comments
 (0)