1 change: 1 addition & 0 deletions include/llama.h
@@ -457,6 +457,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API void llama_mod_n_ctx (struct llama_context * ctx, uint32_t new_n_ctx, struct llama_context_params params);

DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
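For reference, a minimal usage sketch of the proposed entry point (illustrative, not part of this diff). It assumes the current llama_model_load_from_file / llama_init_from_model loading API and a placeholder model path, and it passes the same llama_context_params used at creation so the new memory keeps the original type_k/type_v:

// Usage sketch (illustrative, not part of this PR). Assumes the current
// llama.cpp loading API; "model.gguf" is a placeholder path.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 2048;
    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... decode until the 2048-cell context is nearly full ...

    // Grow the context to 4096 cells at runtime. The same cparams are passed
    // so the re-created memory keeps the original cache types. Shrinking is
    // rejected with an error log.
    llama_mod_n_ctx(ctx, 4096, cparams);

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}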
44 changes: 44 additions & 0 deletions src/llama-context.cpp
@@ -412,6 +412,46 @@ ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}

// Serialize the full context state (logits, embeddings, KV cache) into a buffer.
static std::vector<uint8_t> dump_state(llama_context * ctx) {
    std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
    GGML_ASSERT(written <= state_mem.size());
    return state_mem;
}

// Restore a previously dumped state into the context.
static void load_state(llama_context * ctx, const std::vector<uint8_t> & state_mem) {
    if (llama_state_set_data(ctx, state_mem.data(), state_mem.size()) == 0) {
        // Do not free the context or model here: this helper runs inside a
        // llama_context member function, so destroying the context would be a
        // use-after-free for the caller. Report the failure instead.
        LLAMA_LOG_ERROR("%s: failed to read state\n", __func__);
    }
}

void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params) {
    // Only increasing the context size is supported.
    if (cparams.n_ctx < new_n_ctx) {
        cparams.n_ctx = new_n_ctx;

        llama_memory_params params_mem = {
            /*.type_k =*/ params.type_k,
            /*.type_v =*/ params.type_v,
        };

        // Save the current state, re-create the memory with the updated
        // cparams, then restore the state into the new memory.
        std::vector<uint8_t> state_memory = dump_state(this);
        memory.reset(model.create_memory(params_mem, cparams));
        load_state(this, state_memory);

        // TODO: resize the memory in place rather than re-creating it
        // memory.get()->resize(new_n_ctx);
    } else {
        LLAMA_LOG_ERROR("%s: cannot decrease the context size\n", __func__);
    }
}

uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
@@ -2327,6 +2367,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}

void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params) {
    ctx->mod_n_ctx(new_n_ctx, params);
}

uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}
4 changes: 4 additions & 0 deletions src/llama-context.h
@@ -35,6 +35,10 @@ struct llama_context {

ggml_backend_sched_t get_sched() const;

ggml_context * get_ctx_compute() const;

void mod_n_ctx(uint32_t new_n_ctx, llama_context_params params);

uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
17 changes: 17 additions & 0 deletions src/llama-kv-cache-unified.cpp
@@ -1011,6 +1011,23 @@ uint32_t llama_kv_cache_unified::get_n_stream() const {
return n_stream;
}

// Resize the cells vector so the context size can grow at runtime.
// n_stream is left unchanged for now.
bool llama_kv_cache_unified::resize(uint32_t new_n_ctx) {
    try {
        new_n_ctx = GGML_PAD(new_n_ctx, n_pad);
        for (uint32_t s = 0; s < n_stream; ++s) {
            // only growing is supported; shrinking would drop live cells
            assert(new_n_ctx > v_cells[s].size());
            v_cells[s].resize(new_n_ctx);
        }
        return true;
    } catch (...) {
        return false;
    }
}

bool llama_kv_cache_unified::get_has_shift() const {
bool result = false;

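The TODO in mod_n_ctx points at this method: once resize() is wired in, the call site could attempt the in-place resize first and fall back to the save/recreate/restore round-trip only when the active memory type declines. A hypothetical sketch of that wiring (not part of this diff; memory, dump_state and load_state are the names used above):

// Hypothetical follow-up for the TODO above (not in this diff):
// prefer the in-place resize, fall back to re-creating the memory.
void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params) {
    if (cparams.n_ctx >= new_n_ctx) {
        LLAMA_LOG_ERROR("%s: cannot decrease the context size\n", __func__);
        return;
    }
    cparams.n_ctx = new_n_ctx;

    if (memory && memory->resize(new_n_ctx)) {
        return; // the cells grew in place, no state round-trip needed
    }

    // Fallback path: serialize, rebuild, restore (as in the current diff).
    llama_memory_params params_mem = {
        /*.type_k =*/ params.type_k,
        /*.type_v =*/ params.type_v,
    };
    std::vector<uint8_t> state_memory = dump_state(this);
    memory.reset(model.create_memory(params_mem, cparams));
    load_state(this, state_memory);
}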
3 changes: 3 additions & 0 deletions src/llama-kv-cache-unified.h
@@ -146,6 +146,9 @@ class llama_kv_cache_unified : public llama_memory_i {
uint32_t get_size() const;
uint32_t get_n_stream() const;

// Resize the cells so the context size can grow at runtime.
bool resize(uint32_t new_n_ctx) override;

bool get_has_shift() const;

//
6 changes: 6 additions & 0 deletions src/llama-memory.h
@@ -106,6 +106,12 @@ struct llama_memory_i {

virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;

// Dynamically resize the memory for a new context size.
virtual bool resize(uint32_t) {
    // implemented only for the unified KV cache at the moment
    return false;
}
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
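A standalone miniature of the opt-in pattern introduced here (illustrative only, not llama.cpp code): the base interface's resize() defaults to "unsupported", and each memory implementation opts in by overriding it, so callers can probe support through the return value.

// Standalone miniature of the opt-in virtual resize() pattern (illustrative).
#include <cstdint>
#include <cstdio>
#include <vector>

struct memory_i {
    virtual ~memory_i() = default;
    // default: this memory type does not support resizing
    virtual bool resize(uint32_t /*new_n_cells*/) { return false; }
};

struct unified_cache : memory_i {
    std::vector<int> cells;
    // opt in: grow-only resize of the cell metadata
    bool resize(uint32_t new_n_cells) override {
        if (new_n_cells <= cells.size()) {
            return false;
        }
        cells.resize(new_n_cells);
        return true;
    }
};

int main() {
    unified_cache mem;
    mem.cells.resize(8);
    std::printf("unified: %d\n", mem.resize(16));  // 1 - grown in place
    memory_i base;
    std::printf("default: %d\n", base.resize(16)); // 0 - declines by default
    return 0;
}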