@@ -25,7 +25,11 @@ llama_context::llama_context(
2525
2626 const auto & hparams = model.hparams ;
2727
28- cparams.n_seq_max = std::max (1u , params.n_seq_max );
28+ cparams.n_seq_max = std::max (1u , params.n_seq_max );
29+ if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
30+ throw std::runtime_error (" n_seq_max must be <= " + std::to_string (LLAMA_MAX_PARALLEL_SEQUENCES));
31+ }
32+
2933 cparams.n_threads = params.n_threads ;
3034 cparams.n_threads_batch = params.n_threads_batch ;
3135 cparams.yarn_ext_factor = params.yarn_ext_factor ;
@@ -93,6 +97,7 @@ llama_context::llama_context(
9397 }
9498
9599 cparams.n_ubatch = std::min (cparams.n_batch , params.n_ubatch == 0 ? params.n_batch : params.n_ubatch );
100+
96101 cparams.op_offload = params.op_offload ;
97102
98103 const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max ;
@@ -176,8 +181,9 @@ llama_context::llama_context(
176181 // init the memory module
177182 if (!hparams.vocab_only ) {
178183 llama_memory_params params_mem = {
179- /* .type_k =*/ params.type_k ,
180- /* .type_v =*/ params.type_v ,
184+ /* .type_k =*/ params.type_k ,
185+ /* .type_v =*/ params.type_v ,
186+ /* .swa_full =*/ params.swa_full ,
181187 };
182188
183189 memory.reset (model.create_memory (params_mem, cparams));
@@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
687693
688694 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
689695
696+ // TODO: move the validation to the llama_batch_allocr
690697 if (batch.token ) {
691698 for (int32_t i = 0 ; i < n_tokens; ++i) {
692699 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
693700 LLAMA_LOG_ERROR (" %s: invalid token[%d] = %d\n " , __func__, i, batch.token [i]);
694701 return -1 ;
695702 }
703+
704+ if (batch.seq_id && (batch.seq_id [i][0 ] < 0 || batch.seq_id [i][0 ] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%d] = %d > %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
706+ throw -1 ;
707+ }
696708 }
697709 }
698710
@@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
846858
847859int llama_context::decode (llama_batch & inp_batch) {
848860 if (!memory) {
849- LLAMA_LOG_WARN (" %s: cannot decode batches with this context (use llama_encode () instead)\n " , __func__);
861+ LLAMA_LOG_DEBUG (" %s: cannot decode batches with this context (calling encode () instead)\n " , __func__);
850862 return encode (inp_batch);
851863 }
852864
@@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
855867 return -1 ;
856868 }
857869
870+ if (!inp_batch.pos ) {
871+ if (inp_batch.seq_id ) {
872+ LLAMA_LOG_ERROR (" %s: pos == NULL, but seq_id != NULL\n " , __func__);
873+ return -1 ;
874+ }
875+ }
876+
858877 llama_kv_cache * kv_self = static_cast <llama_kv_cache *>(memory.get ());
859878
860879 // temporary allocate memory for the input batch if needed
861- // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
862- llama_batch_allocr batch_allocr (inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max () + 1 );
880+ llama_batch_allocr batch_allocr (inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max (0 ) + 1 );
863881
864882 const llama_batch & batch = batch_allocr.batch ;
865883
@@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
875893
876894 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
877895
896+ // TODO: move the validation to the llama_batch_allocr
878897 if (batch.token ) {
879898 for (int64_t i = 0 ; i < n_tokens_all; ++i) {
880899 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
881900 LLAMA_LOG_ERROR (" %s: invalid token[%" PRId64 " ] = %d\n " , __func__, i, batch.token [i]);
882- throw std::runtime_error (" invalid token" );
901+ return -1 ;
902+ }
903+
904+ if (batch.seq_id && (batch.seq_id [i][0 ] < 0 || batch.seq_id [i][0 ] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%" PRId64 " ] = %d >= %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
906+ return -1 ;
883907 }
884908 }
885909 }
@@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
947971
948972 // find KV slot
949973 if (!kv_self->find_slot (ubatch)) {
950- LLAMA_LOG_WARN (" %s: failed to find KV cache slot for ubatch of size %d\n " , __func__, ubatch.n_tokens );
951-
952974 return 1 ;
953975 }
954976
@@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() {
20932115 /* .flash_attn =*/ false ,
20942116 /* .no_perf =*/ true ,
20952117 /* .op_offload =*/ true ,
2118+ /* .swa_full =*/ true ,
20962119 };
20972120
20982121 return result;
@@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
22872310 return res ? 0 : -1 ;
22882311}
22892312
2290- //
2291- // kv cache view
2292- //
2293-
2294- llama_kv_cache_view llama_kv_cache_view_init (const llama_context * ctx, int32_t n_seq_max) {
2295- const auto * kv = ctx->get_kv_self ();
2296- if (kv == nullptr ) {
2297- LLAMA_LOG_WARN (" %s: the context does not have a KV cache\n " , __func__);
2298- return {};
2299- }
2300-
2301- return llama_kv_cache_view_init (*kv, n_seq_max);
2302- }
2303-
2304- void llama_kv_cache_view_update (const llama_context * ctx, llama_kv_cache_view * view) {
2305- const auto * kv = ctx->get_kv_self ();
2306- if (kv == nullptr ) {
2307- LLAMA_LOG_WARN (" %s: the context does not have a KV cache\n " , __func__);
2308- return ;
2309- }
2310-
2311- llama_kv_cache_view_update (view, kv);
2312- }
2313-
23142313//
23152314// kv cache
23162315//
23172316
23182317// deprecated
2319- int32_t llama_get_kv_cache_token_count (const llama_context * ctx) {
2320- return llama_kv_self_n_tokens (ctx);
2321- }
2322-
23232318int32_t llama_kv_self_n_tokens (const llama_context * ctx) {
23242319 const auto * kv = ctx->get_kv_self ();
23252320 if (!kv) {
23262321 return 0 ;
23272322 }
23282323
2329- return kv->get_n_tokens ();
2330- }
2324+ int32_t res = 0 ;
23312325
2332- // deprecated
2333- int32_t llama_get_kv_cache_used_cells (const llama_context * ctx) {
2334- return llama_kv_self_used_cells (ctx);
2326+ for (uint32_t s = 0 ; s < ctx->get_cparams ().n_seq_max ; s++) {
2327+ const llama_pos p0 = kv->seq_pos_min (s);
2328+ const llama_pos p1 = kv->seq_pos_max (s);
2329+
2330+ if (p0 >= 0 ) {
2331+ res += (p1 - p0) + 1 ;
2332+ }
2333+ }
2334+
2335+ return res;
23352336}
23362337
2338+ // deprecated
2339+ // note: this is the same as above - will be removed anyway, so it's ok
23372340int32_t llama_kv_self_used_cells (const llama_context * ctx) {
23382341 const auto * kv = ctx->get_kv_self ();
23392342 if (!kv) {
23402343 return 0 ;
23412344 }
23422345
2343- return kv->get_used_cells ();
2344- }
2346+ int32_t res = 0 ;
23452347
2346- // deprecated
2347- void llama_kv_cache_clear (llama_context * ctx) {
2348- llama_kv_self_clear (ctx);
2348+ for (uint32_t s = 0 ; s < ctx->get_cparams ().n_seq_max ; s++) {
2349+ const llama_pos p0 = kv->seq_pos_min (s);
2350+ const llama_pos p1 = kv->seq_pos_max (s);
2351+
2352+ if (p0 >= 0 ) {
2353+ res += (p1 - p0) + 1 ;
2354+ }
2355+ }
2356+
2357+ return res;
23492358}
23502359
23512360void llama_kv_self_clear (llama_context * ctx) {
@@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
23572366 kv->clear ();
23582367}
23592368
2360- // deprecated
2361- bool llama_kv_cache_seq_rm (
2362- llama_context * ctx,
2363- llama_seq_id seq_id,
2364- llama_pos p0,
2365- llama_pos p1) {
2366- return llama_kv_self_seq_rm (ctx, seq_id, p0, p1);
2367- }
2368-
23692369bool llama_kv_self_seq_rm (
23702370 llama_context * ctx,
23712371 llama_seq_id seq_id,
@@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm(
23792379 return kv->seq_rm (seq_id, p0, p1);
23802380}
23812381
2382- // deprecated
2383- void llama_kv_cache_seq_cp (
2384- llama_context * ctx,
2385- llama_seq_id seq_id_src,
2386- llama_seq_id seq_id_dst,
2387- llama_pos p0,
2388- llama_pos p1) {
2389- llama_kv_self_seq_cp (ctx, seq_id_src, seq_id_dst, p0, p1);
2390- }
2391-
23922382void llama_kv_self_seq_cp (
23932383 llama_context * ctx,
23942384 llama_seq_id seq_id_src,
@@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp(
24032393 kv->seq_cp (seq_id_src, seq_id_dst, p0, p1);
24042394}
24052395
2406- // deprecated
2407- void llama_kv_cache_seq_keep (
2408- llama_context * ctx,
2409- llama_seq_id seq_id) {
2410- llama_kv_self_seq_keep (ctx, seq_id);
2411- }
2412-
24132396void llama_kv_self_seq_keep (llama_context * ctx, llama_seq_id seq_id) {
24142397 auto * kv = ctx->get_kv_self ();
24152398 if (!kv) {
@@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
24192402 kv->seq_keep (seq_id);
24202403}
24212404
2422- // deprecated
2423- void llama_kv_cache_seq_add (
2424- llama_context * ctx,
2425- llama_seq_id seq_id,
2426- llama_pos p0,
2427- llama_pos p1,
2428- llama_pos delta) {
2429- llama_kv_self_seq_add (ctx, seq_id, p0, p1, delta);
2430- }
2431-
24322405void llama_kv_self_seq_add (
24332406 llama_context * ctx,
24342407 llama_seq_id seq_id,
@@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add(
24432416 kv->seq_add (seq_id, p0, p1, delta);
24442417}
24452418
2446- // deprecated
2447- void llama_kv_cache_seq_div (
2448- llama_context * ctx,
2449- llama_seq_id seq_id,
2450- llama_pos p0,
2451- llama_pos p1,
2452- int d) {
2453- llama_kv_self_seq_div (ctx, seq_id, p0, p1, d);
2454- }
2455-
24562419void llama_kv_self_seq_div (
24572420 llama_context * ctx,
24582421 llama_seq_id seq_id,
@@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div(
24672430 kv->seq_div (seq_id, p0, p1, d);
24682431}
24692432
2470- // deprecated
2471- llama_pos llama_kv_cache_seq_pos_max (llama_context * ctx, llama_seq_id seq_id) {
2472- return llama_kv_self_seq_pos_max (ctx, seq_id);
2433+ llama_pos llama_kv_self_seq_pos_min (llama_context * ctx, llama_seq_id seq_id) {
2434+ const auto * kv = ctx->get_kv_self ();
2435+ if (!kv) {
2436+ return -1 ;
2437+ }
2438+
2439+ return kv->seq_pos_min (seq_id);
24732440}
24742441
24752442llama_pos llama_kv_self_seq_pos_max (llama_context * ctx, llama_seq_id seq_id) {
24762443 const auto * kv = ctx->get_kv_self ();
24772444 if (!kv) {
2478- return 0 ;
2445+ return - 1 ;
24792446 }
24802447
24812448 return kv->seq_pos_max (seq_id);
24822449}
24832450
2484- // deprecated
2485- void llama_kv_cache_defrag (llama_context * ctx) {
2486- llama_kv_self_defrag (ctx);
2487- }
2488-
24892451void llama_kv_self_defrag (llama_context * ctx) {
24902452 auto * kv = ctx->get_kv_self ();
24912453 if (!kv) {
@@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
24962458 kv->defrag_sched (-1 .0f );
24972459}
24982460
2499- // deprecated
2500- bool llama_kv_cache_can_shift (const llama_context * ctx) {
2501- return llama_kv_self_can_shift (ctx);
2502- }
2503-
25042461bool llama_kv_self_can_shift (const llama_context * ctx) {
25052462 const auto * kv = ctx->get_kv_self ();
25062463 if (!kv) {
@@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
25102467 return kv->get_can_shift ();
25112468}
25122469
2513- // deprecated
2514- void llama_kv_cache_update (llama_context * ctx) {
2515- llama_kv_self_update (ctx);
2516- }
2517-
25182470// llama state API
25192471
25202472// deprecated
@@ -2637,7 +2589,21 @@ int32_t llama_encode(
26372589int32_t llama_decode (
26382590 llama_context * ctx,
26392591 llama_batch batch) {
2640- const int ret = ctx->decode (batch);
2592+ int ret = ctx->decode (batch);
2593+
2594+ // defrag and try again
2595+ // TODO: distinguish return code when we are sure that even after defrag there is no space available
2596+ if (ret == 1 ) {
2597+ llama_kv_self_defrag (ctx);
2598+ ret = ctx->decode (batch);
2599+
2600+ if (ret == 1 ) {
2601+ LLAMA_LOG_WARN (" %s: failed to find KV cache slot for batch of size %d\n " , __func__, batch.n_tokens );
2602+
2603+ return ret;
2604+ }
2605+ }
2606+
26412607 if (ret != 0 ) {
26422608 LLAMA_LOG_ERROR (" %s: failed to decode, ret = %d\n " , __func__, ret);
26432609 }
0 commit comments