diff --git a/include/llama.h b/include/llama.h
index 545e957e5f52b..f1e1baaa6a837 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -457,6 +457,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_batch      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max    (const struct llama_context * ctx);
+    LLAMA_API void     llama_mod_n_ctx    (struct llama_context * ctx, uint32_t new_ctx, struct llama_context_params params);

     DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
     DEPRECATED(LLAMA_API int32_t llama_n_embd     (const struct llama_model * model), "use llama_model_n_embd instead");
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26a5cf9c3f8db..c38ad07027004 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -412,6 +412,46 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }

+std::vector<uint8_t> dump_state(llama_context * ctx) {
+    std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+    const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
+    return state_mem;
+}
+
+void load_state(llama_context * ctx, std::vector<uint8_t> state_mem) {
+    if (llama_state_set_data(ctx, state_mem.data(), state_mem.size()) == 0) {
+        fprintf(stderr, "\n%s : failed to read state\n", __func__);
+
+        // Free up resources
+        llama_free(ctx);
+        llama_free_model(const_cast<llama_model *>(&ctx->get_model()));
+    }
+}
+
+void llama_context::mod_n_ctx(uint32_t new_n_ctx, llama_context_params params) {
+    // Allow only to increase the context size.
+    if (cparams.n_ctx < new_n_ctx) {
+        cparams.n_ctx = new_n_ctx;
+        llama_memory_params params_mem = {
+            /*.type_k =*/ params.type_k,
+            /*.type_v =*/ params.type_v,
+        };
+        // Resets the memory and sets it to new memory params with modified cparams
+        std::vector<uint8_t> state_memory = dump_state(this); // Dump the state here.
+        memory.reset(model.create_memory(params_mem, cparams));
+        load_state(this, state_memory); // Load the state.
+
+        // Frees the memory.
+        std::vector<uint8_t>().swap(state_memory);
+
+        // TODO: Resize the memory rather than re-creating the memory again
+        // memory.get()->resize(new_n_ctx);
+    }
+    else {
+        LLAMA_LOG_ERROR("%s: Cannot decrease the context size.", __func__);
+    }
+}
+
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -2327,6 +2367,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }

+void llama_mod_n_ctx(struct llama_context * ctx, uint32_t new_n_ctx, llama_context_params params) {
+    ctx->mod_n_ctx(new_n_ctx, params);
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 25c143d56dfb2..927a857c22192 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -35,6 +35,10 @@ struct llama_context {

     ggml_backend_sched_t get_sched() const;

+    ggml_context * get_ctx_compute() const;
+
+    void mod_n_ctx(uint32_t new_ctx, llama_context_params params);
+
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
     uint32_t n_batch() const;
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index e539142e6b8cd..83f9fe4e0ff31 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1011,6 +1011,23 @@ uint32_t llama_kv_cache_unified::get_n_stream() const {
     return n_stream;
 }

+// Resizing the cells vector so we can have dynamic ctx.
+// Not modifying n_stream at the moment
+bool llama_kv_cache_unified::resize(uint32_t new_n_ctx) {
+    try {
+        new_n_ctx = GGML_PAD(new_n_ctx, n_pad);
+        // v_cells.resize(n_stream);
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            assert(new_n_ctx > v_cells[s].size());
+            v_cells[s].resize(new_n_ctx);
+        }
+        return true;
+    }
+    catch (...) {
+        return false;
+    }
+}
+
 bool llama_kv_cache_unified::get_has_shift() const {
     bool result = false;
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 342a675962e2a..d7385fd0c1337 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -146,6 +146,9 @@ class llama_kv_cache_unified : public llama_memory_i {
     uint32_t get_size() const;
     uint32_t get_n_stream() const;

+    // Resizing the cells size to get dynamic context size at runtime.
+    bool resize(uint32_t);
+
     bool get_has_shift() const;

     //
diff --git a/src/llama-memory.h b/src/llama-memory.h
index e8ba336e8525d..948fb72f67f3a 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -106,6 +106,12 @@ struct llama_memory_i {

     virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
     virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
+
+    // Dynamically modify the context files.
+    virtual bool resize(uint32_t) {
+        // Implemented only for unified memory at the moment.
+        return false;
+    }
 };

 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
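
For reference, a minimal usage sketch of the new API (not part of the diff). It uses the existing llama.cpp setup calls; the model path and the 2048/4096 context sizes are placeholders. The same llama_context_params used at creation are passed back so the recreated KV cache keeps the original type_k/type_v, and the call only takes effect when the new size is larger than the current one:

    // usage_sketch.cpp -- illustrative only, assumes this PR is applied
    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 2048; // initial context size
        llama_context * ctx = llama_init_from_model(model, cparams);

        // ... decode some tokens, then grow the context window at runtime.
        // Passing cparams again so the rebuilt memory reuses type_k/type_v.
        llama_mod_n_ctx(ctx, 4096, cparams);

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }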