
Commit 9f976e9

Use full SWA cache unless both context shift and fast-forward are disabled
2 parents 5b6ed44 + e298d2f commit 9f976e9

16 files changed: +1417 −642 lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -1446,6 +1446,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_keep = value;
         }
     ));
+    add_opt(common_arg(
+        {"--swa-full"},
+        string_format("use full-size SWA cache (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
+        [](common_params & params) {
+            params.swa_full = true;
+        }
+    ));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1144,6 +1144,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf    = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
+    cparams.swa_full   = params.swa_full;

     if (params.reranking) {
         cparams.embeddings = true;
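The two hunks above wire the new option end to end: the --swa-full handler in common/arg.cpp sets common_params::swa_full, and common_context_params_to_llama() now copies it into llama_context_params. A minimal sketch of that flow, assuming the common.h/llama.h headers from this tree (the wrapper function itself is hypothetical, not part of the diff):

    #include "common.h"
    #include "llama.h"

    // Hypothetical helper: build context params as if the user had passed --swa-full.
    static llama_context_params cparams_with_full_swa() {
        common_params params;       // defaults: swa_full = false (see common/common.h below)
        params.swa_full = true;     // what the --swa-full callback in common/arg.cpp does

        // the hunk in common/common.cpp forwards the flag to the llama API
        llama_context_params cparams = common_context_params_to_llama(params);
        return cparams;             // cparams.swa_full == true here
    }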

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -319,6 +319,7 @@ struct common_params {
     bool flash_attn = false; // flash attention
     bool no_perf    = false; // disable performance metrics
     bool ctx_shift  = true;  // context shift on inifinite text generation
+    bool swa_full   = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads

gpttype_adapter.cpp

Lines changed: 3 additions & 0 deletions
@@ -583,6 +583,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
     draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
     draft_ctx_params.type_k = base_ctx_params.type_k;
     draft_ctx_params.type_v = base_ctx_params.type_v;
+    draft_ctx_params.swa_full = base_ctx_params.swa_full;

     llama_model * draftmodel = llama_model_load_from_file(spec_model_filename.c_str(), draft_model_params);
     draft_ctx = llama_init_from_model(draftmodel, draft_ctx_params);
@@ -1923,6 +1924,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     kcpp_data->use_smartcontext = inputs.use_smartcontext;
     kcpp_data->use_contextshift = inputs.use_contextshift;
     kcpp_data->use_fastforward = inputs.use_fastforward;
+    kcpp_data->swa_full = (inputs.use_fastforward || inputs.use_contextshift)?true:false;
     debugmode = inputs.debugmode;
     draft_ctx = nullptr;
     guidance_ctx = nullptr;
@@ -2318,6 +2320,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }

     llama_ctx_params.flash_attn = kcpp_data->flash_attn;
+    llama_ctx_params.swa_full = kcpp_data->swa_full;
     llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
     llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
     llama_ctx_v4 = llama_init_from_model(llamamodel, llama_ctx_params);
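This is the behavior named in the commit title: the adapter requests the full-size SWA cache whenever context shift or fast-forward is enabled, and only accepts the pruned sliding-window cache when both are disabled, presumably because those features revisit cache positions that a pruned SWA cache would already have evicted. An illustrative restatement of the rule (the helper name is mine, not part of the diff):

    // Keep the full SWA KV cache only when a feature that reuses or rewinds
    // earlier positions (context shift, fast-forward) is enabled.
    static bool want_full_swa_cache(bool use_contextshift, bool use_fastforward) {
        return use_contextshift || use_fastforward;
    }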

include/llama.h

Lines changed: 20 additions & 8 deletions
@@ -363,10 +363,11 @@ extern "C" {

         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };

     // model quantization parameters
@@ -732,10 +733,18 @@ extern "C" {
             llama_pos   p1,
                   int   d);

+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-            llama_seq_id seq_id);
+                    llama_seq_id   seq_id);

     // Defragment the KV cache
     // This will be applied:
@@ -945,9 +954,12 @@ extern "C" {
     // Requires KV cache.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    //   0 - success
-    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    //  < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);
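The header now exposes both the swa_full context flag and a query for the smallest cached position of a sequence. A usage sketch, assuming the llama.h API in this tree (the model path is a placeholder and error handling is omitted):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        cparams.swa_full = true; // keep the full-size SWA cache (new field above)

        llama_context * ctx = llama_init_from_model(model, cparams);

        // After some decoding, inspect which positions are still cached for sequence 0.
        // With a pruned SWA cache, pos_min is typically > 0; both calls return -1 if the sequence is empty.
        llama_pos pos_min = llama_kv_self_seq_pos_min(ctx, 0);
        llama_pos pos_max = llama_kv_self_seq_pos_max(ctx, 0);
        (void) pos_min; (void) pos_max;

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }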

otherarch/otherarch.h

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ struct kcpp_params {
     bool use_smartcontext = false;
     bool use_contextshift = false;
     bool use_fastforward = false;
+    bool swa_full = true;
 };

 // default hparams (GPT-J 6B)

src/llama-context.cpp

Lines changed: 30 additions & 6 deletions
@@ -93,6 +93,7 @@ llama_context::llama_context(
     }

     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;

     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };

         memory.reset(model.create_memory(params_mem, cparams));
@@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {

         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }

@@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf    =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
     };

     return result;
@@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }

+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
+}
+
 // deprecated
 llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_self_seq_pos_max(ctx, seq_id);
@@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }

     return kv->seq_pos_max(seq_id);
@@ -2637,7 +2647,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
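With this change llama_decode() retries once after defragmenting the KV cache when no slot is found, and the header above spells out the return-code contract. A sketch of a caller-side check against those codes, assuming the llama.h API in this tree (batch construction is omitted):

    #include "llama.h"

    // Returns true on success; the comments follow the return codes documented in include/llama.h.
    static bool decode_ok(llama_context * ctx, llama_batch batch) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;  // success
        }
        if (ret == 1) {
            // still no KV slot after the internal defrag + retry:
            // reduce the batch size or increase the context
            return false;
        }
        if (ret == 2) {
            // decoding was aborted
            return false;
        }
        // ret == -1: invalid input batch;  ret < -1: error
        // in every non-zero case the KV cache state was restored before returning
        return false;
    }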
