4 changes: 2 additions & 2 deletions common/common.cpp
@@ -934,7 +934,7 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}

- if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
+ if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
@@ -1041,7 +1041,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
- llama_kv_self_clear(lctx);
+ llama_memory_clear(llama_get_memory(lctx));
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
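For orientation, here is a minimal sketch (not part of the diff) of the call pattern the two hunks above migrate to; `reset_cache` and its `ctx_shift` parameter are hypothetical, and the `llama_memory_*` signatures are taken from the changed lines.

```cpp
#include "llama.h"

// Hypothetical helper showing the pattern these hunks adopt: fetch the memory
// object from the context once, then operate on it instead of calling the old
// llama_kv_self_* functions on the context directly.
static void reset_cache(llama_context * lctx, bool & ctx_shift) {
    auto * mem = llama_get_memory(lctx);

    if (ctx_shift && !llama_memory_can_shift(mem)) {
        ctx_shift = false; // same fallback as common_init_from_params above
    }

    llama_memory_clear(mem); // replaces llama_kv_self_clear(lctx)
}
```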
10 changes: 6 additions & 4 deletions common/speculative.cpp
@@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft(
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;

+ auto * mem = llama_get_memory(ctx);
+
int reuse_i = 0;
int reuse_n = 0;

@@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft(
result.reserve(params.n_draft);

if (reuse_n == 0) {
- llama_kv_self_clear(ctx);
+ llama_memory_clear(mem);

prompt.clear();
} else {
@@ ... @@
}

if (reuse_i > 0) {
- llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
- llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+ llama_memory_seq_rm (mem, 0, 0, reuse_i);
+ llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);

prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}

if (reuse_n < (int) prompt.size()) {
- llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
+ llama_memory_seq_rm (mem, 0, reuse_n, -1);

prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
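The speculative.cpp changes keep the draft model's cached prefix and shift it into place with the new API. Below is a hedged, self-contained sketch of that reuse step; `reuse_cached_prefix` is a hypothetical helper, and it assumes the draft prompt lives in sequence 0 with cache positions equal to indices into `prompt`, as in the hunks above.

```cpp
#include "llama.h"
#include <vector>

static void reuse_cached_prefix(llama_context * ctx, std::vector<llama_token> & prompt,
                                int reuse_i, int reuse_n) {
    auto * mem = llama_get_memory(ctx);

    if (reuse_n == 0) {
        llama_memory_clear(mem); // nothing reusable: drop the whole cache
        prompt.clear();
        return;
    }

    if (reuse_i > 0) {
        // remove the first reuse_i cached tokens of sequence 0 ...
        llama_memory_seq_rm (mem, 0, 0, reuse_i);
        // ... and shift the survivors left so their positions start at 0 again
        llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
    }

    if (reuse_n < (int) prompt.size()) {
        // drop cached tokens past the reused prefix
        llama_memory_seq_rm(mem, 0, reuse_n, -1);
        prompt.erase(prompt.begin() + reuse_n, prompt.end());
    }
}
```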
2 changes: 1 addition & 1 deletion examples/batched.swift/Sources/main.swift
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
}

for i in 1 ..< n_parallel {
- llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+ llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens)
}

if n_parallel > 1 {
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
- llama_kv_self_clear(ctx);
+ llama_memory_clear(llama_get_memory(ctx));

// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
4 changes: 2 additions & 2 deletions examples/gritlm/gritlm.cpp
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}

// clear previous kv_cache values (irrelevant for embeddings)
- llama_kv_self_clear(ctx);
+ llama_memory_clear(llama_get_memory(ctx));
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);

@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

llama_token eos_token = llama_vocab_eos(vocab);

- llama_kv_self_clear(ctx);
+ llama_memory_clear(llama_get_memory(ctx));
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

8 changes: 4 additions & 4 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
}

batch->logits[batch->n_tokens - 1] = true;
- llama_kv_self_clear(context);
+ llama_memory_clear(llama_get_memory(context));

const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
@@ ... @@

LOGi("Benchmark text generation (tg)");

- llama_kv_self_clear(context);
+ llama_memory_clear(llama_get_memory(context));
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {

@@ ... @@

const auto t_tg_end = ggml_time_us();

- llama_kv_self_clear(context);
+ llama_memory_clear(llama_get_memory(context));

const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
- llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
+ llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)));
}
8 changes: 4 additions & 4 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -210,7 +210,7 @@ actor LlamaContext {
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true

- llama_kv_self_clear(context)
+ llama_memory_clear(llama_get_memory(context))

let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;

@@ ... @@

// bench text generation

- llama_kv_self_clear(context)
+ llama_memory_clear(llama_get_memory(context))

let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;

@@ ... @@

let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;

- llama_kv_self_clear(context)
+ llama_memory_clear(llama_get_memory(context))

let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
- llama_kv_self_clear(context)
+ llama_memory_clear(llama_get_memory(context))
}

private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
14 changes: 8 additions & 6 deletions examples/lookahead/lookahead.cpp
@@ -60,6 +60,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();

+ auto * mem = llama_get_memory(ctx);
+
const llama_vocab * vocab = llama_model_get_vocab(model);

// Tokenize the prompt
@@ -94,7 +96,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

for (int s = 1; s < W + G + 1; ++s) {
- llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
+ llama_memory_seq_cp(mem, 0, s, -1, -1);
}

const auto t_enc_end = ggml_time_us();
@@ -427,17 +429,17 @@

// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
- llama_kv_self_seq_rm(ctx, -1, n_past, -1);
+ llama_memory_seq_rm(mem, -1, n_past, -1);

if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
- llama_kv_self_seq_keep(ctx, seq_id_best);
- llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
- llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
+ llama_memory_seq_keep(mem, seq_id_best);
+ llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1);
+ llama_memory_seq_rm (mem, seq_id_best, -1, -1);

for (int s = 1; s < W + G + 1; ++s) {
- llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
+ llama_memory_seq_cp(mem, 0, s, -1, -1);
}
}
}
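The lookahead.cpp hunk above collapses the speculative branches after verification. A standalone sketch of that step, using only the calls that appear in the hunk (`prune_to_best` and its parameters are hypothetical names):

```cpp
#include "llama.h"

static void prune_to_best(llama_context * ctx, llama_seq_id seq_best, int n_seq, llama_pos n_past) {
    auto * mem = llama_get_memory(ctx);

    // drop the cells written by the last speculative batch from every sequence
    llama_memory_seq_rm(mem, -1, n_past, -1);

    if (seq_best != 0) {
        llama_memory_seq_keep(mem, seq_best);           // keep only the accepted branch
        llama_memory_seq_cp (mem, seq_best, 0, -1, -1); // copy it back to sequence 0
        llama_memory_seq_rm (mem, seq_best, -1, -1);    // free the original copy

        for (int s = 1; s < n_seq; ++s) {
            llama_memory_seq_cp(mem, 0, s, -1, -1);     // re-share sequence 0 with all streams
        }
    }
}
```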
2 changes: 1 addition & 1 deletion examples/lookup/lookup.cpp
@@ -181,7 +181,7 @@ int main(int argc, char ** argv){

// KV cache management
// clean the cache of draft tokens that weren't accepted
- llama_kv_self_seq_rm(ctx, 0, n_past, -1);
+ llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1);

common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
12 changes: 7 additions & 5 deletions examples/parallel/parallel.cpp
@@ -194,6 +194,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();

+ auto * mem = llama_get_memory(ctx);
+
const llama_vocab * vocab = llama_model_get_vocab(model);

// load the prompts from an external file if there are any
@@ -259,7 +261,7 @@

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) {
- llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
+ llama_memory_seq_cp(mem, 0, i, -1, -1);
}

LOG_INF("\n");
@@ ... @@
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
- llama_kv_self_seq_rm(ctx, i, -1, -1);
+ llama_memory_seq_rm(mem, i, -1, -1);
// but keep the system prompt
- llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
+ llama_memory_seq_cp(mem, 0, i, -1, -1);
}

LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -447,8 +449,8 @@
}

// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
- llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
+ llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+ llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);

const auto t_main_end = ggml_time_us();

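The parallel.cpp pattern, sketched as two hypothetical helpers (assuming client i uses sequence i + 1 and the shared system prompt lives in sequence 0, as in the hunks above):

```cpp
#include "llama.h"

// copy the cached system prompt (sequence 0) into every client sequence
static void share_system_prompt(llama_context * ctx, int n_clients) {
    auto * mem = llama_get_memory(ctx);
    for (int i = 1; i <= n_clients; ++i) {
        llama_memory_seq_cp(mem, 0, i, -1, -1);
    }
}

// reset one client: drop what it generated, but re-attach the shared system prompt
static void reset_client(llama_context * ctx, int client_id) {
    auto * mem = llama_get_memory(ctx);
    llama_memory_seq_rm(mem, client_id + 1, -1, -1);
    llama_memory_seq_cp(mem, 0, client_id + 1, -1, -1);
}
```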
20 changes: 11 additions & 9 deletions examples/passkey/passkey.cpp
@@ -126,17 +126,19 @@ int main(int argc, char ** argv) {

int n_past = 0;

+ auto * mem = llama_get_memory(ctx);
+
// fill the KV cache
for (int i = 0; i < n_ctx; i += n_batch) {
if (i > 0 && n_grp > 1) {
// if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);

- llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
- llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+ llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd);
+ llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);

- n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
+ n_past = llama_memory_seq_pos_max(mem, 0) + 1;
}

common_batch_clear(batch);
@@ -166,10 +168,10 @@

LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);

- llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
- llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+ llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
+ llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);

- n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
+ n_past = llama_memory_seq_pos_max(mem, 0) + 1;

common_batch_clear(batch);

@@ ... @@
if (n_discard > 0) {
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);

- llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
- llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+ llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
+ llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);

- n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
+ n_past = llama_memory_seq_pos_max(mem, 0) + 1;
}
}

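The passkey.cpp changes above also show the context-shift idiom under the new API: remove a window of tokens, slide the rest down, then re-read the past length from the memory object. A hedged sketch (`shift_out` is a hypothetical helper; a single sequence 0 is assumed, as in passkey.cpp):

```cpp
#include "llama.h"

static int shift_out(llama_context * ctx, int n_keep, int n_discard, int n_ctx) {
    auto * mem = llama_get_memory(ctx);

    // remove n_discard tokens right after the protected prefix ...
    llama_memory_seq_rm (mem, 0, n_keep, n_keep + n_discard);
    // ... and slide the remaining tokens down so positions stay contiguous
    llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);

    // the new past length is recovered from the memory object itself
    return llama_memory_seq_pos_max(mem, 0) + 1;
}
```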
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke

static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
- llama_kv_self_clear(ctx);
+ llama_memory_clear(llama_get_memory(ctx));

// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
2 changes: 1 addition & 1 deletion examples/save-load-state/save-load-state.cpp
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);

// erase whole kv
- llama_kv_self_clear(ctx3);
+ llama_memory_clear(llama_get_memory(ctx3));
fprintf(stderr, "%s : kv cache cleared\n", __func__);

// restore kv into seq 1
4 changes: 2 additions & 2 deletions examples/simple-chat/simple-chat.cpp
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;

- const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
+ const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;

// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ ... @@
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
- int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
+ int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");
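For completeness, the capacity check from the simple-chat.cpp hunk, restated as a standalone predicate (`batch_fits` is a hypothetical name; `llama_n_ctx` is assumed from upstream llama.h):

```cpp
#include "llama.h"

static bool batch_fits(llama_context * ctx, int n_batch_tokens) {
    auto * mem = llama_get_memory(ctx);
    const int n_ctx      = llama_n_ctx(ctx);
    const int n_ctx_used = llama_memory_seq_pos_max(mem, 0); // highest cached position in sequence 0
    return n_ctx_used + n_batch_tokens <= n_ctx;
}
```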
2 changes: 1 addition & 1 deletion examples/speculative-simple/speculative-simple.cpp
@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);

- llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
+ llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
}

if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {