diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c5c450e24..77fe2b34d 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -76,6 +76,7 @@ add_library(${TARGET} STATIC minja.hpp ngram-cache.h ngram-cache.cpp + speculative.cpp ) if (BUILD_SHARED_LIBS) diff --git a/common/common.cpp b/common/common.cpp index ed50a0098..b3238cd8a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -505,6 +505,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_ctx = std::stoi(argv[i]); return true; } + if (arg == "-cd" || arg == "--ctx-size-draft") { + CHECK_ARG + params.n_ctx_draft = std::stoi(argv[i]); + return true; + } if (arg == "--grp-attn-n" || arg == "-gan") { CHECK_ARG params.grp_attn_n = std::stoi(argv[i]); @@ -725,7 +730,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } } return true; - } + } if (arg == "--cfg-negative-prompt") { CHECK_ARG sparams.cfg_negative_prompt = argv[i]; @@ -765,11 +770,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_keep = std::stoi(argv[i]); return true; } - if (arg == "--draft") { + if (arg == "--draft" || arg == "--draft-max" || arg == "--draft-n") { CHECK_ARG params.n_draft = std::stoi(argv[i]); return true; } + if (arg == "--draft-min" || arg == "--draft-n-min") { + CHECK_ARG + params.n_draft_min = std::stoi(argv[i]); + return true; + } + if (arg == "--draft-p-min") { + CHECK_ARG + params.p_draft_min = std::stof(argv[i]); + return true; + } if (arg == "--chunks") { CHECK_ARG params.n_chunks = std::stoi(argv[i]); @@ -934,6 +949,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cache_type_v = argv[++i]; return true; } + if (arg == "-ctkd" || arg == "--cache-type-k-draft") { + params.cache_type_k_draft = argv[++i]; + return true; + } + if (arg == "-ctvd" || arg == "--cache-type-v-draft") { + params.cache_type_v_draft = argv[++i]; + return true; + } if (arg == "-mli" || arg == "--multiline-input") { params.multiline_input = true; return true; @@ -1071,7 +1094,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa size_t pos = 0; while ((pos = servers.find(",")) != std::string::npos) { std::string server = servers.substr(0, pos); - ggml_backend_rpc_buffer_type(server.c_str()); + ggml_backend_rpc_buffer_type(server.c_str()); servers.erase(0, pos + 1); } ggml_backend_rpc_buffer_type(servers.c_str()); @@ -1693,7 +1716,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); options.push_back({ "speculative", "-tbd, --threads-batch-draft N", "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); - options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", "path to static lookup cache to use for lookup decoding (not updated by generation)" }); @@ -1701,6 +1723,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "path to dynamic lookup cache to use for lookup decoding (updated by generation)" }); options.push_back({ "*", "-c, 
--ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); + options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); @@ -1811,6 +1834,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" }); options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() }); options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() }); + options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" }); + options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" }); options.push_back({ "perplexity" }); options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" }); @@ -1893,6 +1918,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); + options.push_back({ "*", "--draft-max, --draft, --draft-n N", + "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); + options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); + options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.p_draft_min }); options.push_back({ "retrieval" }); options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); @@ -2052,7 +2081,7 @@ std::string string_join(const std::vector & strs, const std::string if (strs.empty()) { return ""; } - + std::ostringstream oss; for (size_t i = 0; i < strs.size(); ++i) { if (i > 0) { diff --git a/common/common.h b/common/common.h index 4c62a53e2..f2a658d16 100644 --- a/common/common.h +++ b/common/common.h @@ -83,10 +83,13 @@ struct gpt_params { int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size + int32_t n_ctx_draft = 0; // context size for draft model int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_draft_min = 1; // minimum number of tokens to draft 
during speculative decoding
+    float   p_draft_min = 0.8f; // minimum speculative decoding probability (greedy)
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel = 1; // number of parallel sequences to decode
     int32_t n_sequences = 1; // number of sequences to decode
@@ -207,6 +210,8 @@ struct gpt_params {
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
+    std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
+    std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model

     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 08a19b457..526a47ebc 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }

-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+    ctx_sampling->cur_p = { cur.data(), cur.size(), false };
+
+    llama_token_data_array & cur_p = ctx_sampling->cur_p;

     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
@@ -506,3 +508,47 @@ void llama_sampling_accept(
         llama_sampler_dry_accept(ctx_sampling->smpl, id);
     }
 }
+
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
+    return &ctx_sampling->cur_p;
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return llama_sampling_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+}
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);
+
+        llama_sampling_accept(gsmpl, ctx, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
diff --git a/common/sampling.h b/common/sampling.h
index 1d5bf0b9e..27401145c 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -101,6 +101,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
+    llama_token_data_array cur_p; // current candidates
+
     std::mt19937 rng;
 };
@@ -176,3 +178,11 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+// access the internal list of current candidate tokens
+llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling);
+
+// returns at least 1 token, up to draft.size()
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
+
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
diff --git a/common/speculative.cpp b/common/speculative.cpp
new file mode 100644
index 000000000..d70a278ad
--- /dev/null
+++ b/common/speculative.cpp
@@ -0,0 +1,275 @@
+#include "speculative.h"
+
+#include "common.h"
+#include "sampling.h"
+#include "llama-impl.h"
+
+#include <cstring>
+#include <algorithm>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct llama_speculative {
+    struct llama_context * ctx;
+    struct llama_sampling_context * smpl;
+
+    llama_batch batch;
+    std::vector<llama_token> prompt;
+};
+
+struct llama_speculative * llama_speculative_init(
+        struct llama_context * ctx_dft) {
+    auto * result = new llama_speculative {
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        llama_sampling_params params;
+        params.no_perf = false;
+
+        params.top_k = 40;
+        params.top_p = 0.9;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = llama_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#else
+    {
+        llama_sampling_params params;
+        params.top_k = 10;
+        params.samplers_sequence = {
+            llama_sampler_type::TOP_K,
+        };
+        const auto *model_dft = llama_get_model(ctx_dft);
+        result->smpl = llama_sampling_init(llama_get_model_vocab(model_dft), params);
+    }
+#endif
+
+    return result;
+}
+
+void llama_speculative_free(struct llama_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    llama_sampling_free(spec->smpl);
+
+    llama_batch_free(spec->batch);
+
+    delete spec;
+}
+
+bool llama_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const struct llama_vocab * vocab_tgt = llama_get_model_vocab(model_tgt);
+    const struct llama_vocab * vocab_dft = llama_get_model_vocab(model_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LLAMA_LOG_INFO("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LLAMA_LOG_INFO("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LLAMA_LOG_ERROR("%s: draft model vocab type must match target model to use speculation but "
+                        "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != 
llama_token_eos(model_dft)) {
+        LLAMA_LOG_ERROR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
+        LLAMA_LOG_ERROR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
+        LLAMA_LOG_ERROR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int model_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (model_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LLAMA_LOG_ERROR("%s: draft model vocab must closely match target model to use speculation but "
+                            "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                            __func__, n_vocab_tgt, n_vocab_dft, model_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LLAMA_LOG_ERROR("%s: draft vocab must match target vocab to use speculation but "
+                                "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                                llama_token_to_piece(ctx_tgt, i).c_str(),
+                                llama_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+std::vector<llama_token> llama_speculative_gen_draft(
+        struct llama_speculative * spec,
+        struct llama_speculative_params params,
+        const std::vector<llama_token> & prompt_tgt,
+        llama_token id_last) {
+    auto & batch = spec->batch;
+    auto & ctx = spec->ctx;
+    auto & smpl = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx);
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    // LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    std::vector<llama_token> result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(ctx);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. 
in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    llama_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LLAMA_LOG_INFO("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LLAMA_LOG_INFO("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode(ctx, batch);
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    // LLAMA_LOG_INFO("%s: n_past = %d\n", __func__, n_past);
+
+    llama_batch_clear(batch);
+    llama_batch_add (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LLAMA_LOG_INFO("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode(ctx, batch);
+
+    llama_sampling_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        llama_batch_clear(batch);
+
+        llama_sampling_sample(smpl, ctx, nullptr, 0);
+
+        const auto * cur_p = llama_sampling_get_candidates(smpl);
+
+        // for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+        //     LLAMA_LOG_INFO(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+        //             k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        // }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        llama_sampling_accept(smpl, ctx, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        llama_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(ctx, batch);
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
diff --git a/common/speculative.h b/common/speculative.h
new file mode 100644
index 000000000..faa6ee542
--- /dev/null
+++ b/common/speculative.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "llama.h"
+
+#include <vector>
+
+struct llama_speculative;
+
+struct llama_speculative_params {
+    int n_draft = 16; // max drafted tokens
+    int n_reuse = 256;
+
+    float p_min = 0.75f; // min probability required to accept a token in the draft
+};
+
+struct llama_speculative * llama_speculative_init(struct llama_context * ctx_dft);
+
+void llama_speculative_free(struct llama_speculative * spec);
+
+bool llama_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
+// sample up to n_draft tokens and add them to the batch using the draft model
+std::vector<llama_token> llama_speculative_gen_draft(
+        struct 
llama_speculative * spec, + struct llama_speculative_params params, + const std::vector & prompt, + llama_token id_last); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f7de82006..7d29149d8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2,6 +2,8 @@ #include "utils.hpp" #include "common.h" +#include "speculative.h" +#include "sampling.h" #include "json-schema-to-grammar.h" #include "llama.h" #include "grammar-parser.h" @@ -148,14 +150,14 @@ static std::string remove_simple_function_calls(const std::string& content) { size_t pos = 0; while ((pos = cleaned.find(func_pattern, pos)) != std::string::npos) { size_t func_start = pos; - + // Find the opening brace for arguments size_t brace_pos = cleaned.find('{', pos); if (brace_pos == std::string::npos) { pos += func_pattern.length(); continue; } - + // Find the matching closing brace int brace_count = 1; size_t end_pos = brace_pos + 1; @@ -164,7 +166,7 @@ static std::string remove_simple_function_calls(const std::string& content) { else if (cleaned[end_pos] == '}') brace_count--; end_pos++; } - + if (brace_count == 0) { // Remove the entire function call cleaned.erase(func_start, end_pos - func_start); @@ -186,7 +188,7 @@ static std::string remove_xml_function_calls(const std::string& content) { pos = tool_call_start + 11; continue; } - + // Remove the entire XML tool call block cleaned.erase(tool_call_start, tool_call_end - tool_call_start + 12); pos = tool_call_start; @@ -196,17 +198,17 @@ static std::string remove_xml_function_calls(const std::string& content) { static std::string clean_all_function_call_formats(const std::string& content) { std::string cleaned = content; - + // Remove XML format first cleaned = remove_xml_function_calls(cleaned); - + // Then remove simple format cleaned = remove_simple_function_calls(cleaned); - + // Trim whitespace from cleaned content cleaned.erase(0, cleaned.find_first_not_of(" \t\n\r")); cleaned.erase(cleaned.find_last_not_of(" \t\n\r") + 1); - + return cleaned; } @@ -230,6 +232,13 @@ struct slot_params { bool timings_per_token = false; json input_prefix; json input_suffix; + + // speculative decoding parameters + struct { + int n_max = 16; // max drafted tokens + int n_min = 0; // min drafted tokens to accept + float p_min = 0.75f; // min probability required to accept a token in the draft + } speculative; }; struct server_slot { @@ -293,6 +302,15 @@ struct server_slot { int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width + // speculative decoding + struct llama_speculative * spec = nullptr; + llama_context * ctx_dft = nullptr; + llama_batch batch_spec = {}; + + // speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + int32_t n_past_se = 0; // self-extend // stats @@ -321,28 +339,32 @@ struct server_slot { n_past_se = 0; generated_token_probs.clear(); - + // Reset streaming tool call state previous_msg = ik_chat_msg(); current_msg = ik_chat_msg(); tool_call_ids.clear(); + + // Reset speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } // Update chat message and compute diffs for streaming tool calls // Based on original llama.cpp update_chat_msg pattern const ik_chat_msg & update_chat_msg(std::vector & diffs) { ik_chat_msg previous = current_msg; - + try { // Parse generated text incrementally (is_partial = true during generation) bool is_partial = !stopped_eos && !stopped_word && 
!stopped_limit; ik_chat_msg new_msg = parse_chat_message_incremental(generated_text, is_partial, oaicompat_model); - + if (!new_msg.empty()) { // Ensure tool call IDs are set consistently across streaming chunks new_msg.ensure_tool_call_ids_set(tool_call_ids, generate_tool_call_id); current_msg = new_msg; - + // Compute diffs for streaming diffs = ik_chat_msg_diff::compute_diffs(previous, current_msg); } @@ -350,7 +372,7 @@ struct server_slot { // If parsing fails, don't update current_msg and return empty diffs diffs.clear(); } - + return current_msg; } @@ -413,17 +435,17 @@ struct server_slot { timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - + timings.predicted_n = n_decoded; timings.predicted_ms = (ggml_time_us() - t_start_generation) / 1e3; timings.predicted_per_token_ms = t_token_generation / n_decoded; timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - //// Add speculative metrics - //if (n_draft_total > 0) { - // timings.draft_n = n_draft_total; - // timings.draft_n_accepted = n_draft_accepted; - //} + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } return timings; } @@ -797,6 +819,11 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; + // For speculative decoding + llama_model * model_draft = nullptr; + llama_context * ctx_draft = nullptr; + llama_context_params cparams_dft; + int32_t n_ctx; // total context for all clients / slots // system prompt @@ -829,11 +856,28 @@ struct server_context { model = nullptr; } + // Free draft model and context if they exist + if (ctx_draft) { + llama_free(ctx_draft); + ctx_draft = nullptr; + } + if (model_draft) { + llama_free_model(model_draft); + model_draft = nullptr; + } + // Clear any sampling context for (server_slot & slot : slots) { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); } + if (slot.ctx_dft) { + llama_free(slot.ctx_dft); + } + if (slot.spec) { + llama_speculative_free(slot.spec); + } + llama_batch_free(slot.batch_spec); } llama_batch_free(batch); @@ -869,6 +913,41 @@ struct server_context { chat_templates = llama_chat_templates_from_model(model, params.chat_template); } GGML_ASSERT(chat_templates.template_default.get() != nullptr); + + // Load draft model for speculative decoding if specified + if (!params.model_draft.empty()) { + LOG_INFO("loading draft model", {{"model", params.model_draft}}); + + gpt_params params_dft; + params_dft.model = params.model_draft; + params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft; + params_dft.n_gpu_layers = params.n_gpu_layers_draft; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft; + params_dft.cache_type_v = params.cache_type_v_draft.empty() ? 
params.cache_type_v : params.cache_type_v_draft; + params_dft.flash_attn = params.flash_attn; + + llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); + + llama_model * model_dft = llama_init_dft.model; + if (model_dft == nullptr) { + LOG_ERROR("failed to load draft model", {{"model", params.model_draft}}); + return false; + } + + if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) { + LOG_ERROR("the draft model is not compatible with the target model", {}); + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context); + + cparams_dft = llama_context_params_from_gpt_params(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + model_draft = llama_init_dft.model; + ctx_draft = llama_init_dft.context; + } return true; } @@ -943,6 +1022,23 @@ struct server_context { slot.sparams = params.sparams; + // Initialize speculative decoding if a draft model is loaded + if (ctx_draft) { + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + + slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); + if (slot.ctx_dft == nullptr) { + LOG_ERROR("failed to create draft context", {}); + return; + } + + slot.spec = llama_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + LOG_ERROR("failed to create speculator", {}); + return; + } + } + slot.reset(); slots.push_back(slot); @@ -1134,6 +1230,16 @@ struct server_context { slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + // speculative decoding parameters + slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min); + + // Clamp speculative parameters + slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); + slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); + slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); + if (slot.sparams.penalty_last_n < -1) { throw std::runtime_error("Error: repeat_last_n must be >= -1"); } @@ -2737,6 +2843,118 @@ struct server_context { slot.i_batch = -1; } + + // Do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.spec) { + continue; + } + + if (slot.state != SLOT_STATE_PROCESSING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_predict > 0) { + n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1); + } + + LOG_VERBOSE("max possible draft", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max} + }); + + if (n_draft_max < slot.params.speculative.n_min) { + LOG_VERBOSE("the max possible draft is too small", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max}, + {"n_min", slot.params.speculative.n_min} + }); + continue; + } + + llama_token id = slot.sampled; + + struct llama_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max; + params_spec.p_min = 
slot.params.speculative.p_min; + + const std::vector & cached_text_tokens = slot.cache_tokens; + std::vector draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + LOG_VERBOSE("ignoring small draft", { + {"id_slot", slot.id}, + {"draft_size", (int) draft.size()}, + {"n_min", slot.params.speculative.n_min} + }); + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + llama_batch_clear(slot.batch_spec); + llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true); + } + + LOG_VERBOSE("decoding speculative batch", { + {"id_slot", slot.id}, + {"size", slot.batch_spec.n_tokens} + }); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + std::vector ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special); + // result.prob = 1.0f; // set later + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + LOG_VERBOSE("speculative decoding result", { + {"id_slot", slot.id}, + {"accepted", (int) ids.size() - 1}, + {"total", (int) draft.size()}, + {"new_n_past", slot.n_past} + }); + } } LOG_VERBOSE("run slots completed", {}); @@ -2763,10 +2981,10 @@ static json format_final_response_oaicompat(const json& request, json result, co // Parse tool calls using model-specific format detection std::string model_name = json_value(request, "model", std::string("")); - + // Use the same parsing logic as streaming path for consistency ik_chat_msg parsed_msg = parse_chat_message_incremental(content, false, model_name); - + // Convert to JSON format for compatibility json tool_calls = json::array(); for (const auto & tc : parsed_msg.tool_calls) { @@ -2779,9 +2997,9 @@ static json format_final_response_oaicompat(const json& request, json result, co {"id", tc.id} }); } - + bool has_tool_calls = !tool_calls.empty(); - + // Use cleaned content from parser (following original llama.cpp pattern) if (has_tool_calls) { content = parsed_msg.content; // Parser already cleaned the content @@ -2863,14 +3081,14 @@ static std::vector format_partial_response_oaicompat(server_task_result ta // Use generated_text (complete content) for finish_reason logic, not content (empty in streaming) std::string generated_text = json_value(result, "generated_text", std::string("")); ik_chat_msg final_msg = parse_chat_message_incremental(generated_text, false, modelname); - + // Debug logging LOG_INFO("DEBUG: Streaming finish_reason check", { {"generated_text", generated_text}, - {"model_name", modelname}, + {"model_name", modelname}, {"tool_calls_count", 
final_msg.tool_calls.size()} }); - + finish_reason = final_msg.tool_calls.empty() ? "stop" : "tool_calls"; } @@ -2878,18 +3096,18 @@ static std::vector format_partial_response_oaicompat(server_task_result ta // Follow original llama.cpp pattern: Always process diffs and add final chunk std::vector streaming_chunks; - + // Extract diffs from task result (populated by send_partial_response) // Following original llama.cpp pattern where diffs are stored in task result std::vector diffs; - + if (result.contains("oaicompat_msg_diffs") && result["oaicompat_msg_diffs"].is_array()) { for (const auto & diff_json : result["oaicompat_msg_diffs"]) { ik_chat_msg_diff diff; - + // Extract content delta diff.content_delta = diff_json.value("content_delta", ""); - + // Extract tool call data if (diff_json.contains("tool_call_index")) { diff.tool_call_index = diff_json["tool_call_index"]; @@ -2902,13 +3120,13 @@ static std::vector format_partial_response_oaicompat(server_task_result ta } else { diff.tool_call_index = std::string::npos; } - + diffs.push_back(diff); } } - + streaming_chunks = generate_streaming_chunks(diffs, completion_id, modelname); - + // Always add final chunk (like original llama.cpp) if (!finish_reason.empty()) { json finish_chunk = { @@ -2922,6 +3140,7 @@ static std::vector format_partial_response_oaicompat(server_task_result ta }; streaming_chunks.push_back(finish_chunk); } + if (server_task_result_dict.count(task_result.id) > 0) { for (auto& chunk : streaming_chunks) @@ -3092,7 +3311,7 @@ int main(int argc, char ** argv) { // TODO: not great to use extern vars server_log_json = params.log_json; server_verbose = params.verbosity > 0; - + // struct that contains llama context and inference server_context ctx_server;
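
Note (not part of the patch): the sketch below is a minimal illustration of how the pieces added in common/speculative.h and common/sampling.h are meant to fit together in a target-model generation loop, mirroring the server changes above. Names such as ctx_tgt, ctx_dft, ctx_sampling, prompt_tgt, id_last and n_past stand for caller-side state and are assumptions, not code from this PR; error handling is omitted.

// Illustrative sketch only, under the assumptions stated above.
struct llama_speculative * spec = llama_speculative_init(ctx_dft);

struct llama_speculative_params params_spec;
params_spec.n_draft = 16;    // max tokens to draft per step
params_spec.p_min   = 0.75f; // stop drafting below this confidence

llama_batch batch = llama_batch_init(params_spec.n_draft + 1, 0, 1);

while (true) {
    // draft a continuation with the small model, reusing its KV cache where possible
    std::vector<llama_token> draft = llama_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // evaluate the last accepted token plus the draft on the target model in one batch
    llama_batch_clear(batch);
    llama_batch_add(batch, id_last, n_past, { 0 }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_add(batch, draft[i], n_past + 1 + i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch);

    // sample with the target sampler at each drafted position; returns the accepted
    // prefix of the draft plus one token sampled by the target (at least 1 token)
    std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(ctx_sampling, ctx_tgt, draft);

    n_past += ids.size();

    // the cached prompt now contains the previous token and all accepted tokens except
    // the newest one, which becomes the next id_last
    prompt_tgt.push_back(id_last);
    prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
    id_last = ids.back();

    // drop the KV cache entries of the rejected draft tokens on the target model
    llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);

    if (id_last == llama_token_eos(llama_get_model(ctx_tgt))) {
        break;
    }
}

llama_batch_free(batch);
llama_speculative_free(spec);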