Skip to content

Commit 27456b5

Browse files
committed
Dynamically modify the context without resetting the memory.
1 parent 1e7b16f commit 27456b5

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

src/llama-context.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,11 +439,13 @@ void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params, c
439439
/*.type_k =*/ params.type_k,
440440
/*.type_v =*/ params.type_v,
441441
};
442-
442+
/*
443443
// Resets the memory and sets it to new memory params with modified cparams
444444
dump_state(this, dump_file_path); // Dump the state here.
445445
memory.reset(model.create_memory(params_mem, cparams));
446446
load_state(this, dump_file_path); // Load the state.
447+
*/
448+
memory.get()->resize(new_n_ctx);
447449
}
448450
else{
449451
LLAMA_LOG_ERROR("%s: Cannot decrease the context size.", __func__);

src/llama-kv-cache-unified.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,23 @@ uint32_t llama_kv_cache_unified::get_n_stream() const {
10071007
return n_stream;
10081008
}
10091009

1010+
// Resizing the cells vector so we can have dynamic ctx.
1011+
// Not modifying n_stream at the moment
1012+
bool llama_kv_cache_unified::resize(uint32_t new_n_ctx){
1013+
try{
1014+
new_n_ctx = GGML_PAD(new_n_ctx, n_pad);
1015+
// v_cells.resize(n_stream);
1016+
for (uint32_t s = 0; s < n_stream; ++s) {
1017+
assert(new_n_ctx > v_cells[s].size());
1018+
v_cells[s].resize(new_n_ctx);
1019+
}
1020+
return true;
1021+
}
1022+
catch (...){
1023+
return false;
1024+
}
1025+
}
1026+
10101027
bool llama_kv_cache_unified::get_has_shift() const {
10111028
bool result = false;
10121029

src/llama-kv-cache-unified.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ class llama_kv_cache_unified : public llama_memory_i {
146146
uint32_t get_size() const;
147147
uint32_t get_n_stream() const;
148148

149+
// Resize the cells to support a dynamic context size at runtime.
150+
bool resize(uint32_t);
151+
149152
bool get_has_shift() const;
150153

151154
//

src/llama-memory.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ struct llama_memory_i {
106106

107107
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
108108
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
109+
110+
// Dynamically resize the context (grow the memory's cell capacity) at runtime.
// Default implementation: resizing is unsupported, reports failure.
// Overridden by llama_kv_cache_unified, which performs the actual resize.
virtual bool resize(uint32_t) {
    // Implemented only for unified memory at the moment.
    return false;
}
109115
};
110116

111117
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

0 commit comments

Comments
 (0)