
Commit 73b85e4

Dynamically modify the context size without resetting the memory.
1 parent f471c74 commit 73b85e4

6 files changed (+36, -12 lines)

include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -490,7 +490,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
-    LLAMA_API void llama_mod_n_ctx     (struct llama_context * ctx, uint32_t new_ctx, struct llama_context_params params, const char* dump_file_path);
+    LLAMA_API void llama_mod_n_ctx     (struct llama_context * ctx, uint32_t new_ctx, struct llama_context_params params);

     DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
     DEPRECATED(LLAMA_API int32_t llama_n_embd     (const struct llama_model * model), "use llama_model_n_embd instead");
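For orientation, a minimal caller-side sketch of the new entry point could look like the following (the model path and context sizes are hypothetical, the current llama.h loader helpers are assumed, and error handling is omitted):

    // Hypothetical usage: create a context with a small n_ctx, then grow it
    // at runtime with the new llama_mod_n_ctx() call.
    struct llama_model_params   mparams = llama_model_default_params();
    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 4096;

    struct llama_model   * model = llama_model_load_from_file("model.gguf", mparams);
    struct llama_context * ctx   = llama_init_from_model(model, cparams);

    // ... decode some tokens ...

    // Grow the context to 8192 tokens without recreating the context.
    // Only increases are accepted; a smaller value is rejected with an error log.
    llama_mod_n_ctx(ctx, 8192, cparams);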

src/llama-context.cpp

Lines changed: 7 additions & 3 deletions
@@ -420,7 +420,7 @@ void load_state(llama_context* ctx, const char* dump_file_path){
     }
 }

-void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params, const char* dump_file_path = "dump_state.bin"){
+void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params){
     // Allow only to increase the context size.
     if (cparams.n_ctx < new_n_ctx) {
         cparams.n_ctx = new_n_ctx;
@@ -429,10 +429,14 @@ void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params, c
             /*.type_v =*/ params.type_v,
         };

+        /*
         // Resets the memory and sets it to new memory params with modified cparams
         dump_state(this, dump_file_path); // Dump the state here.
         memory.reset(model.create_memory(params_mem, cparams));
         load_state(this, dump_file_path); // Load the state.
+        */
+
+        memory.get()->resize(new_n_ctx);
     }
     else{
         LLAMA_LOG_ERROR("%s: Cannot decrease the context size.", __func__);
@@ -2293,8 +2297,8 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }

-void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params, const char* dump_file_path){
-    ctx->mod_n_ctx(new_n_ctx, params, dump_file_path);
+void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params){
+    ctx->mod_n_ctx(new_n_ctx, params);
 }

 uint32_t llama_n_batch(const llama_context * ctx) {
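Taken together, these hunks change the strategy inside mod_n_ctx: the earlier path that dumped the state, rebuilt the memory via model.create_memory(), and reloaded the state is now commented out, and the existing memory object is resized in place instead. A condensed sketch of the resulting control flow, reconstructed from the hunks above with the params_mem initialization abbreviated, is:

    void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params) {
        // Only growing the context is supported.
        if (cparams.n_ctx < new_n_ctx) {
            cparams.n_ctx = new_n_ctx;
            // ... build params_mem from params (type_k, type_v, ...) ...

            // Old path (disabled): dump_state(), memory.reset(model.create_memory(...)), load_state().
            // New path: resize the existing memory in place.
            memory.get()->resize(new_n_ctx);
        } else {
            LLAMA_LOG_ERROR("%s: Cannot decrease the context size.", __func__);
        }
    }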

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ struct llama_context {

     ggml_context * get_ctx_compute() const;

-    void mod_n_ctx(uint32_t new_ctx, llama_context_params params, const char* dump_file_path);
+    void mod_n_ctx(uint32_t new_ctx, llama_context_params params);

     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;

src/llama-kv-cache-unified.cpp

Lines changed: 13 additions & 0 deletions
@@ -704,6 +704,19 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }

+// Resizing the cells vector so we can have dynamic ctx.
+bool llama_kv_cache_unified::resize(uint32_t new_n_ctx){
+    try{
+        assert(new_n_ctx > cells.size());
+        new_n_ctx = GGML_PAD(new_n_ctx, n_pad);
+        cells.resize(new_n_ctx);
+        return true;
+    }
+    catch (...){
+        return false;
+    }
+}
+
 bool llama_kv_cache_unified::get_has_shift() const {
     return cells.get_has_shift();
 }
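Note that the requested size is rounded up to a multiple of n_pad via GGML_PAD before the cells vector grows, so the cache keeps its padding invariant. For example, with a hypothetical padding of 32:

    // GGML_PAD(x, n) rounds x up to the next multiple of n, so the cells
    // vector is always resized to a padded count.
    uint32_t requested = 8190;
    uint32_t padded    = GGML_PAD(requested, 32);   // padded == 8192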

src/llama-kv-cache-unified.h

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ class llama_kv_cache_unified : public llama_memory_i {
     //

     uint32_t get_size() const;
-
+    // Resizing the cells size to get dynamic context size at runtime.
+    bool resize(uint32_t);
     bool get_has_shift() const;

     //

src/llama-memory.h

Lines changed: 12 additions & 6 deletions
@@ -80,29 +80,35 @@ struct llama_memory_i {

     // getters
     virtual bool get_can_shift() const = 0;
-
+
     //
     // ops
     //
-
+
     // if data == true, the data buffers will also be cleared together with the metadata
     virtual void clear(bool data) = 0;
-
+
     virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
     virtual void seq_keep(llama_seq_id seq_id) = 0;
     virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
-
+
     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
-
+
     //
     // state write/read
     //
-
+
     virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
     virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
+
+    // Dynamically modify the context files.
+    virtual bool resize(uint32_t) {
+        // Not implemented yet
+        return false;
+    };
 };

 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
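Because the base class only supplies this stub, memory implementations that do not override resize() report failure by returning false. A caller that wants to be defensive (a hypothetical sketch, not part of this commit) could therefore check the return value instead of assuming success:

    // Hypothetical defensive use of the new virtual.
    if (!memory.get()->resize(new_n_ctx)) {
        LLAMA_LOG_ERROR("%s: this memory type does not support resizing\n", __func__);
    }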
