
Commit 8006f3b

llama : remove implicit recurrent state rollbacks
1 parent 124c222 commit 8006f3b

File tree

25 files changed: +399 additions, -1107 deletions

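Every hunk shown here makes the same substitution at the call sites: the llama_past_* wrappers are swapped back for their llama_kv_cache_* counterparts. As a rough sketch of the updated call-site pattern, modeled on the cvector-generator hunk below (the helper name eval_fresh is hypothetical and not part of the commit; assumes llama.h and <vector> are included and ctx is an initialized llama_context):

// Sketch only: decode a batch that should not see any previously cached state.
// Clearing the KV cache here replaces the former llama_past_clear() call.
static bool eval_fresh(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_kv_cache_clear(ctx);
    return llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0)) == 0;
}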

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         if (llama_model_has_decoder(model)) {
             llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         }
-        llama_past_clear(lctx);
+        llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
     }

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
 
         const auto t_pp_start = ggml_time_us();
 
-        llama_past_clear(ctx);
+        llama_kv_cache_clear(ctx);
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -142,7 +142,7 @@ int main(int argc, char ** argv) {
 
         if (is_pp_shared) {
             for (int32_t i = 1; i < pl; ++i) {
-                llama_past_seq_cp(ctx, 0, i, -1, -1);
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
         }
 

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 {
 }
 
 for i in 1 ..< n_parallel {
-    llama_past_seq_cp(context, 0, Int32(i), -1, -1)
+    llama_kv_cache_seq_cp(context, 0, Int32(i), -1, -1)
 }
 
 if n_parallel > 1 {

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
     //// assign the system KV cache to all parallel sequences
     //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     //for (int32_t i = 1; i < n_parallel; ++i) {
-    //    llama_past_seq_cp(ctx, 0, i, -1, -1);
+    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
     //}
 
     if (n_parallel > 1) {
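
The parallel-decoding examples rely on the sequence-copy variant of the same API. A minimal sketch of that pattern, assuming the shared prompt has already been decoded on sequence 0 and n_parallel sequences are in use (passing -1 for both p0 and p1 selects the entire position range):

// Sketch only: let sequences 1..n_parallel-1 reuse the prompt's KV cells
// instead of decoding the prompt once per sequence.
for (int32_t i = 1; i < n_parallel; ++i) {
    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}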

examples/cvector-generator/cvector-generator.cpp

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
-    llama_past_clear(ctx);
+    llama_kv_cache_clear(ctx);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_past_clear(ctx);
+    llama_kv_cache_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     }
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_past_clear(ctx);
+    llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, true);
     llama_set_causal_attn(ctx, false);
 
@@ -99,7 +99,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     const llama_model * model = llama_get_model(ctx);
     llama_token eos_token = llama_token_eos(model);
 
-    llama_past_clear(ctx);
+    llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 1 deletion
@@ -494,7 +494,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_past_clear(ctx);
+        llama_kv_cache_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;

examples/infill/infill.cpp

Lines changed: 2 additions & 2 deletions
@@ -375,8 +375,8 @@ int main(int argc, char ** argv) {
                 LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                         n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_past_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 
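As a worked example of this context shift, with hypothetical values that are not taken from the commit (params.n_keep = 4, n_discard = 256), the two calls amount to:

// Sketch only: keep positions 0..4, drop the 256 cached positions 5..260,
// then slide positions 261..n_past-1 down by 256 so they follow position 4.
llama_kv_cache_seq_rm (ctx, 0, 5,   261);
llama_kv_cache_seq_add(ctx, 0, 261, n_past, -256);
n_past -= 256;  // generation continues from the shifted position
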
examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -1566,7 +1566,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_past_clear(ctx);
+        llama_kv_cache_clear(ctx);
 
         // cool off before the test
         if (params.delay) {
@@ -1606,7 +1606,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_past_clear(ctx);
+            llama_kv_cache_clear(ctx);
 
             uint64_t t_start = get_time_ns();
 
0 commit comments
