From de5ecab4fb02cc9185cbc61a9e92c77f939c4582 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 02:51:00 +0000 Subject: [PATCH 01/27] server : integrate speculative decoding --- common/speculative.cpp | 280 +++++++++++++++++++++++++++++++++++++ common/speculative.h | 28 ++++ examples/server/server.cpp | 230 +++++++++++++++++++++++++++++- 3 files changed, 533 insertions(+), 5 deletions(-) create mode 100644 common/speculative.cpp create mode 100644 common/speculative.h diff --git a/common/speculative.cpp b/common/speculative.cpp new file mode 100644 index 000000000..843bd1ddb --- /dev/null +++ b/common/speculative.cpp @@ -0,0 +1,280 @@ +#include "speculative.h" + +#include "log.h" +#include "common.h" +#include "sampling.h" + +#include +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct common_speculative { + struct llama_context * ctx; + struct common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt; +}; + +struct common_speculative * common_speculative_init( + struct llama_context * ctx_dft) { + auto * result = new common_speculative { + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt = */ {}, + }; + + // TODO: optimize or pass from outside? +#if 0 + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 40; + params.top_p = 0.9; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_INFILL, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#else + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 10; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#endif + + return result; +} + +void common_speculative_free(struct common_speculative * spec) { + if (spec == nullptr) { + return; + } + + common_sampler_free(spec->smpl); + + llama_batch_free(spec->batch); + + delete spec; +} + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { + const struct llama_model * model_tgt = llama_get_model(ctx_tgt); + const struct llama_model * model_dft = llama_get_model(ctx_dft); + + const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); + LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(vocab_dft); + LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + return false; + } + + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { + LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), 
llama_vocab_get_add_eos(vocab_tgt)); + LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + return false; + } + + { + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + + const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " + "token %d content differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return false; + } + } + } + + return true; +} + +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt_tgt, + llama_token id_last) { + auto & batch = spec->batch; + auto & ctx = spec->ctx; + auto & smpl = spec->smpl; + auto & prompt = spec->prompt; + + auto * mem = llama_get_memory(ctx); + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx) - params.n_draft; + + const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_tgt.size() && + i + cur < (int) prompt.size() && + prompt_tgt[i_start + cur] == prompt[i + cur]) { + cur++; + } + + if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + + llama_tokens result; + result.reserve(params.n_draft); + + if (reuse_n == 0) { + llama_memory_clear(mem, false); + + prompt.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. 
in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { + result.push_back(prompt[i]); + + if (params.n_draft <= (int) result.size()) { + break; + } + } + + return result; + } + + if (reuse_i > 0) { + llama_memory_seq_rm (mem, 0, 0, reuse_i); + llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); + + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + } + + if (reuse_n < (int) prompt.size()) { + llama_memory_seq_rm (mem, 0, reuse_n, -1); + + prompt.erase(prompt.begin() + reuse_n, prompt.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + + prompt.push_back(prompt_tgt[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx, batch); + } + + const llama_pos n_past = prompt.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + + llama_decode(ctx, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_draft; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_draft <= (int) result.size()) { + break; + } + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx, batch); + + prompt.push_back(id); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h new file mode 100644 index 000000000..75f2e3112 --- /dev/null +++ b/common/speculative.h @@ -0,0 +1,28 @@ +#pragma once + +#include "llama.h" +#include "common.h" + +struct common_speculative; + +struct common_speculative_params { + int n_draft = 16; // max drafted tokens + int n_reuse = 256; + + float p_min = 0.75f; // min probability required to accept a token in the draft +}; + +struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); + +void common_speculative_free(struct common_speculative * spec); + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft); + +// sample up to n_draft tokens and add them to the batch using the draft model +std::vector common_speculative_gen_draft( + struct common_speculative * spec, + struct 
common_speculative_params params, + const std::vector & prompt, + llama_token id_last); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 42f0b17bd..a4c1e992c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -230,6 +230,13 @@ struct slot_params { bool timings_per_token = false; json input_prefix; json input_suffix; + + // speculative decoding parameters + struct { + int n_max = 0; // max drafted tokens + int n_min = 0; // min drafted tokens to accept + float p_min = 0.75f; // min probability required to accept a token in the draft + } speculative; }; struct server_slot { @@ -292,6 +299,15 @@ struct server_slot { int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width + + // speculative decoding + struct common_speculative * spec = nullptr; + llama_context * ctx_dft = nullptr; + llama_batch batch_spec = {}; + + // speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted int32_t n_past_se = 0; // self-extend @@ -326,6 +342,10 @@ struct server_slot { previous_msg = ik_chat_msg(); current_msg = ik_chat_msg(); tool_call_ids.clear(); + + // Reset speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } // Update chat message and compute diffs for streaming tool calls @@ -419,11 +439,11 @@ struct server_slot { timings.predicted_per_token_ms = t_token_generation / n_decoded; timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - //// Add speculative metrics - //if (n_draft_total > 0) { - // timings.draft_n = n_draft_total; - // timings.draft_n_accepted = n_draft_accepted; - //} + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } return timings; } @@ -796,6 +816,10 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; + + // For speculative decoding + llama_init_result model_dft_owned; + llama_context_params cparams_dft; int32_t n_ctx; // total context for all clients / slots @@ -833,6 +857,13 @@ struct server_context { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); } + if (slot.ctx_dft) { + llama_free(slot.ctx_dft); + } + if (slot.spec) { + common_speculative_free(slot.spec); + } + llama_batch_free(slot.batch_spec); } llama_batch_free(batch); @@ -860,6 +891,56 @@ struct server_context { add_bos_token = llama_should_add_bos_token(model); GGML_ASSERT(llama_add_eos_token(model) != 1); + // Load draft model for speculative decoding if specified + if (!params.speculative_model.empty()) { + LOG_INFO("loading draft model", {{"model", params.speculative_model}}); + + gpt_params params_dft = params; + params_dft.model = params.speculative_model; + params_dft.n_ctx = params.speculative_n_ctx == 0 ? 
params.n_ctx / params.n_parallel : params.speculative_n_ctx; + params_dft.n_gpu_layers = params.speculative_n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params.speculative_cache_type_k; + params_dft.cache_type_v = params.speculative_cache_type_v; + + llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); + + llama_model * model_dft = llama_init_dft.model; + if (model_dft == nullptr) { + LOG_ERROR("failed to load draft model", {{"model", params.speculative_model}}); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) { + LOG_ERROR("the draft model is not compatible with the target model", {}); + return false; + } + + // Store the draft context initialization parameters for later use + cparams_dft = llama_context_default_params(); + cparams_dft.n_ctx = params_dft.n_ctx; + cparams_dft.n_batch = cparams_dft.n_ctx; + cparams_dft.n_ubatch = params_dft.n_ubatch; + cparams_dft.freq_base = params_dft.rope_freq_base; + cparams_dft.freq_scale = params_dft.rope_freq_scale; + cparams_dft.yarn_ext_factor = params_dft.yarn_ext_factor; + cparams_dft.yarn_attn_factor = params_dft.yarn_attn_factor; + cparams_dft.yarn_beta_fast = params_dft.yarn_beta_fast; + cparams_dft.yarn_beta_slow = params_dft.yarn_beta_slow; + cparams_dft.yarn_orig_ctx = params_dft.yarn_orig_ctx; + cparams_dft.clip_kqv = params_dft.clip_kqv; + cparams_dft.pooling_type = params_dft.pooling_type; + cparams_dft.defrag_thold = params_dft.defrag_thold; + cparams_dft.type_k = params_dft.type_k; + cparams_dft.type_v = params_dft.type_v; + cparams_dft.logits_all = false; + cparams_dft.embedding = false; + cparams_dft.offload_kqv = params_dft.offload_kqv; + + // Keep the draft model alive + model_dft_owned = llama_init_dft; + } + return true; } @@ -909,6 +990,23 @@ struct server_context { slot.ga_w = ga_w; slot.sparams = params.sparams; + + // Initialize speculative decoding if a draft model is loaded + if (model_dft_owned.context) { + slot.batch_spec = llama_batch_init(params.speculative_n_max + 1, 0, 1); + + slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft); + if (slot.ctx_dft == nullptr) { + LOG_ERROR("failed to create draft context", {}); + return; + } + + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + LOG_ERROR("failed to create speculator", {}); + return; + } + } slot.reset(); @@ -1100,6 +1198,16 @@ struct server_context { slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + + // speculative decoding parameters + slot.params.speculative.n_max = json_value(data, "speculative.n_max", 0); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", 0); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", 0.75f); + + // Clamp speculative parameters + slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); + slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); + slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); if (slot.sparams.penalty_last_n < -1) { throw std::runtime_error("Error: repeat_last_n must be >= -1"); @@ -2704,6 +2812,118 @@ struct server_context { slot.i_batch = -1; } + + // Do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.spec) { + continue; + } + + 
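// The block below implements the draft-then-verify step for this slot:
//   1) bound the draft size by the remaining context and by n_predict
//   2) ask the draft model to extend cache_tokens + the last sampled token
//   3) decode that token plus the whole draft on the target model in one batch
//   4) accept the longest draft prefix the target sampler agrees with, plus
//      one token sampled by the target itself
//   5) trim the target KV cache back to the new n_past so rejected draft
//      positions are discarded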
if (slot.state != SLOT_STATE_PROCESSING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_predict > 0) { + n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1); + } + + LOG_VERBOSE("max possible draft", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max} + }); + + if (n_draft_max < slot.params.speculative.n_min) { + LOG_VERBOSE("the max possible draft is too small", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max}, + {"n_min", slot.params.speculative.n_min} + }); + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const std::vector & cached_text_tokens = slot.cache_tokens; + std::vector draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + LOG_VERBOSE("ignoring small draft", { + {"id_slot", slot.id}, + {"draft_size", (int) draft.size()}, + {"n_min", slot.params.speculative.n_min} + }); + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + llama_batch_clear(slot.batch_spec); + llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true); + } + + LOG_VERBOSE("decoding speculative batch", { + {"id_slot", slot.id}, + {"size", slot.batch_spec.n_tokens} + }); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + std::vector ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special); + result.prob = 1.0f; // set later + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + LOG_VERBOSE("speculative decoding result", { + {"id_slot", slot.id}, + {"accepted", (int) ids.size() - 1}, + {"total", (int) draft.size()}, + {"new_n_past", slot.n_past} + }); + } } LOG_VERBOSE("run slots completed", {}); From 98f6a48e68a88757cc4ba2d92c6cdc69a3d82250 Mon Sep 17 00:00:00 2001 From: "T. M." 
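The new common/speculative.{h,cpp} API introduced above is deliberately small: create a speculator bound to the draft context, repeatedly ask it for a draft continuation of the target's token history, and free it when the slot is destroyed. A minimal caller-side sketch follows, assuming ctx_tgt/ctx_dft, prompt_tgt and id_last are provided by the caller (PATCH 07 later renames these functions to llama_speculative_*):

    // sketch only: drive one draft step with the API introduced in this patch
    if (common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        struct common_speculative * spec = common_speculative_init(ctx_dft);

        struct common_speculative_params params_spec;
        params_spec.n_draft = 16;    // draft at most 16 tokens per step
        params_spec.p_min   = 0.75f; // stop drafting once the draft model is unsure

        // prompt_tgt: tokens already decoded by the target, id_last: last sampled token
        std::vector<llama_token> draft =
            common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

        // ... decode id_last + draft on the target and keep the longest prefix
        //     the target sampler agrees with ...

        common_speculative_free(spec);
    }
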
Date: Fri, 25 Jul 2025 02:58:12 +0000 Subject: [PATCH 02/27] server: Fix field names --- examples/server/server.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a4c1e992c..0b5ca1f16 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -892,16 +892,16 @@ struct server_context { GGML_ASSERT(llama_add_eos_token(model) != 1); // Load draft model for speculative decoding if specified - if (!params.speculative_model.empty()) { - LOG_INFO("loading draft model", {{"model", params.speculative_model}}); + if (!params.model_draft.empty()) { + LOG_INFO("loading draft model", {{"model", params.model_draft}}); gpt_params params_dft = params; - params_dft.model = params.speculative_model; - params_dft.n_ctx = params.speculative_n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative_n_ctx; - params_dft.n_gpu_layers = params.speculative_n_gpu_layers; + params_dft.model = params.model_draft; + params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_gpu_layers_draft; + params_dft.n_gpu_layers = params.n_gpu_layers_draft; params_dft.n_parallel = 1; - params_dft.cache_type_k = params.speculative_cache_type_k; - params_dft.cache_type_v = params.speculative_cache_type_v; + params_dft.cache_type_k = params.cache_type_k; + params_dft.cache_type_v = params.cache_type_v; llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); @@ -993,7 +993,7 @@ struct server_context { // Initialize speculative decoding if a draft model is loaded if (model_dft_owned.context) { - slot.batch_spec = llama_batch_init(params.speculative_n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1); slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft); if (slot.ctx_dft == nullptr) { From 80a0579a292066820e53a19ce002d5c143393c75 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Fri, 25 Jul 2025 03:00:35 +0000 Subject: [PATCH 03/27] server: fix include, whitespace --- examples/server/server.cpp | 97 +++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0b5ca1f16..d724bcf09 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2,6 +2,7 @@ #include "utils.hpp" #include "common.h" +#include "speculative.h" #include "json-schema-to-grammar.h" #include "llama.h" #include "grammar-parser.h" @@ -148,14 +149,14 @@ static std::string remove_simple_function_calls(const std::string& content) { size_t pos = 0; while ((pos = cleaned.find(func_pattern, pos)) != std::string::npos) { size_t func_start = pos; - + // Find the opening brace for arguments size_t brace_pos = cleaned.find('{', pos); if (brace_pos == std::string::npos) { pos += func_pattern.length(); continue; } - + // Find the matching closing brace int brace_count = 1; size_t end_pos = brace_pos + 1; @@ -164,7 +165,7 @@ static std::string remove_simple_function_calls(const std::string& content) { else if (cleaned[end_pos] == '}') brace_count--; end_pos++; } - + if (brace_count == 0) { // Remove the entire function call cleaned.erase(func_start, end_pos - func_start); @@ -186,7 +187,7 @@ static std::string remove_xml_function_calls(const std::string& content) { pos = tool_call_start + 11; continue; } - + // Remove the entire XML tool call block cleaned.erase(tool_call_start, tool_call_end - tool_call_start + 12); pos = tool_call_start; @@ -196,17 +197,17 @@ static std::string remove_xml_function_calls(const std::string& content) { static std::string clean_all_function_call_formats(const std::string& content) { std::string cleaned = content; - + // Remove XML format first cleaned = remove_xml_function_calls(cleaned); - + // Then remove simple format cleaned = remove_simple_function_calls(cleaned); - + // Trim whitespace from cleaned content cleaned.erase(0, cleaned.find_first_not_of(" \t\n\r")); cleaned.erase(cleaned.find_last_not_of(" \t\n\r") + 1); - + return cleaned; } @@ -230,7 +231,7 @@ struct slot_params { bool timings_per_token = false; json input_prefix; json input_suffix; - + // speculative decoding parameters struct { int n_max = 0; // max drafted tokens @@ -299,12 +300,12 @@ struct server_slot { int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width - + // speculative decoding struct common_speculative * spec = nullptr; llama_context * ctx_dft = nullptr; llama_batch batch_spec = {}; - + // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted @@ -337,12 +338,12 @@ struct server_slot { n_past_se = 0; generated_token_probs.clear(); - + // Reset streaming tool call state previous_msg = ik_chat_msg(); current_msg = ik_chat_msg(); tool_call_ids.clear(); - + // Reset speculative decoding stats n_draft_total = 0; n_draft_accepted = 0; @@ -352,17 +353,17 @@ struct server_slot { // Based on original llama.cpp update_chat_msg pattern const ik_chat_msg & update_chat_msg(std::vector & diffs) { ik_chat_msg previous = current_msg; - + try { // Parse generated text incrementally (is_partial = true during generation) bool is_partial = !stopped_eos && !stopped_word && !stopped_limit; ik_chat_msg new_msg = parse_chat_message_incremental(generated_text, is_partial, oaicompat_model); - + if (!new_msg.empty()) { // Ensure 
tool call IDs are set consistently across streaming chunks new_msg.ensure_tool_call_ids_set(tool_call_ids, generate_tool_call_id); current_msg = new_msg; - + // Compute diffs for streaming diffs = ik_chat_msg_diff::compute_diffs(previous, current_msg); } @@ -370,7 +371,7 @@ struct server_slot { // If parsing fails, don't update current_msg and return empty diffs diffs.clear(); } - + return current_msg; } @@ -433,7 +434,7 @@ struct server_slot { timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - + timings.predicted_n = n_decoded; timings.predicted_ms = (ggml_time_us() - t_start_generation) / 1e3; timings.predicted_per_token_ms = t_token_generation / n_decoded; @@ -816,7 +817,7 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; - + // For speculative decoding llama_init_result model_dft_owned; llama_context_params cparams_dft; @@ -894,7 +895,7 @@ struct server_context { // Load draft model for speculative decoding if specified if (!params.model_draft.empty()) { LOG_INFO("loading draft model", {{"model", params.model_draft}}); - + gpt_params params_dft = params; params_dft.model = params.model_draft; params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_gpu_layers_draft; @@ -902,20 +903,20 @@ struct server_context { params_dft.n_parallel = 1; params_dft.cache_type_k = params.cache_type_k; params_dft.cache_type_v = params.cache_type_v; - + llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); - + llama_model * model_dft = llama_init_dft.model; if (model_dft == nullptr) { - LOG_ERROR("failed to load draft model", {{"model", params.speculative_model}}); + LOG_ERROR("failed to load draft model", {{"model", params.model_draft}}); return false; } - + if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) { LOG_ERROR("the draft model is not compatible with the target model", {}); return false; } - + // Store the draft context initialization parameters for later use cparams_dft = llama_context_default_params(); cparams_dft.n_ctx = params_dft.n_ctx; @@ -936,7 +937,7 @@ struct server_context { cparams_dft.logits_all = false; cparams_dft.embedding = false; cparams_dft.offload_kqv = params_dft.offload_kqv; - + // Keep the draft model alive model_dft_owned = llama_init_dft; } @@ -990,17 +991,17 @@ struct server_context { slot.ga_w = ga_w; slot.sparams = params.sparams; - + // Initialize speculative decoding if a draft model is loaded if (model_dft_owned.context) { slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1); - + slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft); if (slot.ctx_dft == nullptr) { LOG_ERROR("failed to create draft context", {}); return; } - + slot.spec = common_speculative_init(slot.ctx_dft); if (slot.spec == nullptr) { LOG_ERROR("failed to create speculator", {}); @@ -1198,12 +1199,12 @@ struct server_context { slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - + // speculative decoding parameters slot.params.speculative.n_max = json_value(data, "speculative.n_max", 0); slot.params.speculative.n_min = json_value(data, "speculative.n_min", 0); slot.params.speculative.p_min = json_value(data, "speculative.p_min", 0.75f); - + // Clamp speculative parameters 
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); @@ -2812,7 +2813,7 @@ struct server_context { slot.i_batch = -1; } - + // Do speculative decoding for (auto & slot : slots) { if (!slot.is_processing() || !slot.spec) { @@ -2950,10 +2951,10 @@ static json format_final_response_oaicompat(const json& request, json result, co // Parse tool calls using model-specific format detection std::string model_name = json_value(request, "model", std::string("")); - + // Use the same parsing logic as streaming path for consistency ik_chat_msg parsed_msg = parse_chat_message_incremental(content, false, model_name); - + // Convert to JSON format for compatibility json tool_calls = json::array(); for (const auto & tc : parsed_msg.tool_calls) { @@ -2966,9 +2967,9 @@ static json format_final_response_oaicompat(const json& request, json result, co {"id", tc.id} }); } - + bool has_tool_calls = !tool_calls.empty(); - + // Use cleaned content from parser (following original llama.cpp pattern) if (has_tool_calls) { content = parsed_msg.content; // Parser already cleaned the content @@ -3050,14 +3051,14 @@ static std::vector format_partial_response_oaicompat(server_task_result ta // Use generated_text (complete content) for finish_reason logic, not content (empty in streaming) std::string generated_text = json_value(result, "generated_text", std::string("")); ik_chat_msg final_msg = parse_chat_message_incremental(generated_text, false, modelname); - + // Debug logging LOG_INFO("DEBUG: Streaming finish_reason check", { {"generated_text", generated_text}, - {"model_name", modelname}, + {"model_name", modelname}, {"tool_calls_count", final_msg.tool_calls.size()} }); - + finish_reason = final_msg.tool_calls.empty() ? 
"stop" : "tool_calls"; } @@ -3065,18 +3066,18 @@ static std::vector format_partial_response_oaicompat(server_task_result ta // Follow original llama.cpp pattern: Always process diffs and add final chunk std::vector streaming_chunks; - + // Extract diffs from task result (populated by send_partial_response) // Following original llama.cpp pattern where diffs are stored in task result std::vector diffs; - + if (result.contains("oaicompat_msg_diffs") && result["oaicompat_msg_diffs"].is_array()) { for (const auto & diff_json : result["oaicompat_msg_diffs"]) { ik_chat_msg_diff diff; - + // Extract content delta diff.content_delta = diff_json.value("content_delta", ""); - + // Extract tool call data if (diff_json.contains("tool_call_index")) { diff.tool_call_index = diff_json["tool_call_index"]; @@ -3089,13 +3090,13 @@ static std::vector format_partial_response_oaicompat(server_task_result ta } else { diff.tool_call_index = std::string::npos; } - + diffs.push_back(diff); } } - + streaming_chunks = generate_streaming_chunks(diffs, completion_id, modelname); - + // Always add final chunk (like original llama.cpp) if (!finish_reason.empty()) { json finish_chunk = { @@ -3109,7 +3110,7 @@ static std::vector format_partial_response_oaicompat(server_task_result ta }; streaming_chunks.push_back(finish_chunk); } - + // Return streaming chunks (could be just final chunk if no diffs) if (!streaming_chunks.empty()) { return streaming_chunks; @@ -3275,7 +3276,7 @@ int main(int argc, char ** argv) { // TODO: not great to use extern vars server_log_json = params.log_json; server_verbose = params.verbosity > 0; - + // struct that contains llama context and inference server_context ctx_server; From 5c96a7f619d89bc88773b8e0614ded77d2bbcd14 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Fri, 25 Jul 2025 03:31:53 +0000 Subject: [PATCH 04/27] fix compile errors in speculative.cpp --- common/speculative.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 843bd1ddb..aa7592b5d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -15,7 +15,7 @@ struct common_speculative { struct common_sampler * smpl; llama_batch batch; - llama_tokens prompt; + std::vector prompt; }; struct common_speculative * common_speculative_init( @@ -84,13 +84,13 @@ bool common_speculative_are_compatible( const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); - LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + LLAMA_LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); const bool vocab_type_dft = llama_vocab_type(vocab_dft); - LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + LLAMA_LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + LLAMA_LOG_ERR("%s: draft model vocab type must match target model to use speculation but " "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); return false; } @@ -99,9 +99,9 @@ bool common_speculative_are_compatible( llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { - LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); - LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); - LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + LLAMA_LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LLAMA_LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); + LLAMA_LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); return false; } @@ -112,7 +112,7 @@ bool common_speculative_are_compatible( const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + LLAMA_LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; @@ -122,7 +122,7 @@ bool common_speculative_are_compatible( const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation 
but " + LLAMA_LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " "token %d content differs - target '%s', draft '%s'\n", __func__, i, common_token_to_piece(ctx_tgt, i).c_str(), common_token_to_piece(ctx_dft, i).c_str()); @@ -169,7 +169,7 @@ llama_tokens common_speculative_gen_draft( } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + LLAMA_LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); llama_tokens result; result.reserve(params.n_draft); @@ -211,7 +211,7 @@ llama_tokens common_speculative_gen_draft( common_batch_clear(batch); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + //LLAMA_LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); prompt.push_back(prompt_tgt[i]); @@ -219,21 +219,21 @@ llama_tokens common_speculative_gen_draft( // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + //LLAMA_LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); } const llama_pos n_past = prompt.size(); - LOG_DBG("%s: n_past = %d\n", __func__, n_past); + LLAMA_LOG_DBG("%s: n_past = %d\n", __func__, n_past); common_batch_clear(batch); common_batch_add (batch, id_last, n_past, { 0 }, true); prompt.push_back(id_last); - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + //LLAMA_LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); llama_decode(ctx, batch); @@ -248,7 +248,7 @@ llama_tokens common_speculative_gen_draft( const auto * cur_p = common_sampler_get_candidates(smpl); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + LLAMA_LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); } From 99c1ef3c0107e5dcc2343f10e344ab072956f6e2 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Fri, 25 Jul 2025 03:39:59 +0000 Subject: [PATCH 05/27] add llama_sampling_sample_and_accept_n to sampling --- common/sampling.cpp | 31 +++++++++++++++++++++++++++++++ common/sampling.h | 19 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index 08a19b457..bd9156264 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -506,3 +506,34 @@ void llama_sampling_accept( llama_sampler_dry_accept(ctx_sampling->smpl, id); } } + +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); + + llama_sampling_accept(gsmpl, ctx, id, true); + + result.push_back(id); + + if (draft[i] != id) { + break; + } + } + + if (i == draft.size()) { + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); + + llama_sampling_accept(gsmpl, ctx, id, true); + + result.push_back(id); + } + + return result; +} + diff --git a/common/sampling.h b/common/sampling.h index 1d5bf0b9e..2517daee6 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -176,3 +176,22 @@ void llama_sampling_accept( struct llama_context * ctx_main, llama_token id, bool apply_grammar); + +// generalized version of common_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now +// +// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// common_sampler_sample(gsmpl, ctx, idx); +// common_sampler_accept(gsmpl, token, true); +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft); + From 642b70a64b63a4f89dd454a9d73e2af76fbe5bfa Mon Sep 17 00:00:00 2001 From: "T. M." 
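The helper added in this patch walks the draft token by token, samples from the target's logits at each position, and stops at the first position where the target sampler disagrees. A hedged usage sketch with the signature from this patch, assuming ctx_sampling, ctx_tgt and draft are caller state and that idxs holds int logit indices (PATCH 06 below drops the explicit idxs argument and simply uses positions 0..draft.size()):

    // idxs are the logit indices of id_last + the draft inside the verification
    // batch, so idxs.size() == draft.size() + 1
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = (int) i;
    }

    std::vector<llama_token> ids =
        llama_sampling_sample_and_accept_n(ctx_sampling, ctx_tgt, idxs, draft);

    // ids is the accepted draft prefix plus one token sampled by the target
    // itself, so ids.size() - 1 drafted tokens were accepted
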
Date: Fri, 25 Jul 2025 04:00:02 +0000 Subject: [PATCH 06/27] finish porting speculative decoding in server --- common/sampling.cpp | 9 +++---- common/sampling.h | 19 ++------------- examples/server/server.cpp | 50 ++++++++++++++++++-------------------- 3 files changed, 28 insertions(+), 50 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index bd9156264..7d460b579 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -507,15 +507,12 @@ void llama_sampling_accept( } } -std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft) { - GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); - +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { std::vector result; - result.reserve(idxs.size()); size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i); llama_sampling_accept(gsmpl, ctx, id, true); @@ -527,7 +524,7 @@ std::vector llama_sampling_sample_and_accept_n(struct llama_samplin } if (i == draft.size()) { - const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i); llama_sampling_accept(gsmpl, ctx, id, true); diff --git a/common/sampling.h b/common/sampling.h index 2517daee6..405f5a638 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -177,21 +177,6 @@ void llama_sampling_accept( llama_token id, bool apply_grammar); -// generalized version of common_sampler_sample -// -// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match -// if the sampler disagrees at some point, we stop and return the accepted tokens up to now -// -// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); -// -// is equivalent to -// -// common_sampler_sample(gsmpl, ctx, idx); -// common_sampler_accept(gsmpl, token, true); -// -// requires: idxs.size() == draft.size() + 1 -// -// returns at least 1 token, up to idxs.size() -// -std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft); +// returns at least 1 token, up to draft.size() +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d724bcf09..ad934137d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,6 +3,7 @@ #include "common.h" #include "speculative.h" +#include "sampling.h" #include "json-schema-to-grammar.h" #include "llama.h" #include "grammar-parser.h" @@ -819,7 +820,8 @@ struct server_context { bool add_bos_token = true; // For speculative decoding - llama_init_result model_dft_owned; + llama_model * model_draft = nullptr; + llama_context * ctx_draft = nullptr; llama_context_params cparams_dft; int32_t n_ctx; // total context for all clients / slots @@ -853,6 +855,16 @@ struct server_context { model = nullptr; } + // Free draft model and context if they exist + if (ctx_draft) { + llama_free(ctx_draft); + ctx_draft = nullptr; + } + if (model_draft) { + llama_free_model(model_draft); + model_draft = nullptr; + } + // Clear any sampling 
context for (server_slot & slot : slots) { if (slot.ctx_sampling != nullptr) { @@ -917,29 +929,13 @@ struct server_context { return false; } - // Store the draft context initialization parameters for later use - cparams_dft = llama_context_default_params(); - cparams_dft.n_ctx = params_dft.n_ctx; - cparams_dft.n_batch = cparams_dft.n_ctx; - cparams_dft.n_ubatch = params_dft.n_ubatch; - cparams_dft.freq_base = params_dft.rope_freq_base; - cparams_dft.freq_scale = params_dft.rope_freq_scale; - cparams_dft.yarn_ext_factor = params_dft.yarn_ext_factor; - cparams_dft.yarn_attn_factor = params_dft.yarn_attn_factor; - cparams_dft.yarn_beta_fast = params_dft.yarn_beta_fast; - cparams_dft.yarn_beta_slow = params_dft.yarn_beta_slow; - cparams_dft.yarn_orig_ctx = params_dft.yarn_orig_ctx; - cparams_dft.clip_kqv = params_dft.clip_kqv; - cparams_dft.pooling_type = params_dft.pooling_type; - cparams_dft.defrag_thold = params_dft.defrag_thold; - cparams_dft.type_k = params_dft.type_k; - cparams_dft.type_v = params_dft.type_v; - cparams_dft.logits_all = false; - cparams_dft.embedding = false; - cparams_dft.offload_kqv = params_dft.offload_kqv; - - // Keep the draft model alive - model_dft_owned = llama_init_dft; + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context); + + cparams_dft = llama_context_params_from_gpt_params(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + model_draft = llama_init_dft.model; + ctx_draft = llama_init_dft.context; } return true; @@ -993,10 +989,10 @@ struct server_context { slot.sparams = params.sparams; // Initialize speculative decoding if a draft model is loaded - if (model_dft_owned.context) { + if (ctx_draft) { slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1); - slot.ctx_dft = llama_init_from_model(model_dft_owned.model, cparams_dft); + slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); if (slot.ctx_dft == nullptr) { LOG_ERROR("failed to create draft context", {}); return; @@ -2906,7 +2902,7 @@ struct server_context { result.tok = ids[i]; result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special); - result.prob = 1.0f; // set later + // result.prob = 1.0f; // set later if (!process_token(result, slot)) { // release slot because of stop condition From 422af9eeca4fecadae235729bd4900922d80bfb6 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 04:39:43 +0000 Subject: [PATCH 07/27] port functions from common/speculative, common/sampling --- common/CMakeLists.txt | 1 + common/sampling.cpp | 8 ++- common/sampling.h | 5 ++ common/speculative.cpp | 137 ++++++++++++++++++++--------------------- common/speculative.h | 19 +++--- 5 files changed, 89 insertions(+), 81 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 789154e83..2c7e65f6e 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -74,6 +74,7 @@ add_library(${TARGET} STATIC train.cpp ngram-cache.h ngram-cache.cpp + speculative.cpp ) if (BUILD_SHARED_LIBS) diff --git a/common/sampling.cpp b/common/sampling.cpp index 7d460b579..9c5580e88 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl( cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + ctx_sampling->cur_p = { cur.data(), cur.size(), false }; + + llama_token_data_array & cur_p = ctx_sampling->cur_p; // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? 
params.penalty_prompt_tokens : prev; @@ -507,6 +509,10 @@ void llama_sampling_accept( } } +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) { + return &ctx_sampling->cur_p; +} + std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { std::vector result; diff --git a/common/sampling.h b/common/sampling.h index 405f5a638..d209a59f6 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -101,6 +101,8 @@ struct llama_sampling_context { size_t n_valid; // Number of correct top tokens with correct probabilities. + llama_token_data_array cur_p; // current candidates + std::mt19937 rng; }; @@ -178,5 +180,8 @@ void llama_sampling_accept( bool apply_grammar); // returns at least 1 token, up to draft.size() +// access the internal list of current candidate tokens +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling); + std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft); diff --git a/common/speculative.cpp b/common/speculative.cpp index aa7592b5d..ae326be4d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,8 +1,8 @@ #include "speculative.h" -#include "log.h" #include "common.h" #include "sampling.h" +#include "llama-impl.h" #include #include @@ -10,17 +10,17 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -struct common_speculative { +struct llama_speculative { struct llama_context * ctx; - struct common_sampler * smpl; + struct llama_sampling_context * smpl; llama_batch batch; std::vector prompt; }; -struct common_speculative * common_speculative_init( +struct llama_speculative * llama_speculative_init( struct llama_context * ctx_dft) { - auto * result = new common_speculative { + auto * result = new llama_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), @@ -30,7 +30,7 @@ struct common_speculative * common_speculative_init( // TODO: optimize or pass from outside? 
#if 0 { - common_params_sampling params; + llama_sampling_params params; params.no_perf = false; params.top_k = 40; @@ -42,90 +42,87 @@ struct common_speculative * common_speculative_init( COMMON_SAMPLER_TYPE_INFILL, }; - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + result->smpl = llama_sampler_init(llama_get_model(ctx_dft), params); } #else { - common_params_sampling params; - params.no_perf = false; - + llama_sampling_params params; params.top_k = 10; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, + params.samplers_sequence = { + llama_sampler_type::TOP_K, }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + const auto *model_dft = llama_get_model(ctx_dft); + result->smpl = llama_sampling_init(llama_get_model_vocab(model_dft), params); } #endif return result; } -void common_speculative_free(struct common_speculative * spec) { +void llama_speculative_free(struct llama_speculative * spec) { if (spec == nullptr) { return; } - common_sampler_free(spec->smpl); + llama_sampling_free(spec->smpl); llama_batch_free(spec->batch); delete spec; } -bool common_speculative_are_compatible( +bool llama_speculative_are_compatible( const struct llama_context * ctx_tgt, const struct llama_context * ctx_dft) { const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_dft = llama_get_model(ctx_dft); - const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + const struct llama_vocab * vocab_tgt = llama_get_model_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_get_model_vocab(model_dft); - const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); - LLAMA_LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LLAMA_LOG_INFO("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); - const bool vocab_type_dft = llama_vocab_type(vocab_dft); - LLAMA_LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + const bool vocab_type_dft = llama_vocab_type(model_dft); + LLAMA_LOG_INFO("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - LLAMA_LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + LLAMA_LOG_ERROR("%s: draft model vocab type must match target model to use speculation but " "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); return false; } - if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || - llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || - llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || - llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { - LLAMA_LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); - LLAMA_LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); - LLAMA_LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + 
llama_token_eos(model_tgt) != llama_token_eos(model_dft)) { + LLAMA_LOG_ERROR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LLAMA_LOG_ERROR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt)); + LLAMA_LOG_ERROR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft)); return false; } { - const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); - const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); - const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + const int model_diff = std::abs(n_vocab_tgt - n_vocab_dft); - if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LLAMA_LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + if (model_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LLAMA_LOG_ERROR("%s: draft model vocab must closely match target model to use speculation but " "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + __func__, n_vocab_tgt, n_vocab_dft, model_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); - const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + const char * token_text_tgt = llama_token_get_text(model_tgt, i); + const char * token_text_dft = llama_token_get_text(model_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LLAMA_LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " + LLAMA_LOG_ERROR("%s: draft vocab vocab must match target vocab to use speculation but " "token %d content differs - target '%s', draft '%s'\n", __func__, i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); + llama_token_to_piece(ctx_tgt, i).c_str(), + llama_token_to_piece(ctx_dft, i).c_str()); return false; } } @@ -134,18 +131,16 @@ bool common_speculative_are_compatible( return true; } -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt_tgt, +std::vector llama_speculative_gen_draft( + struct llama_speculative * spec, + struct llama_speculative_params params, + const std::vector & prompt_tgt, llama_token id_last) { auto & batch = spec->batch; auto & ctx = spec->ctx; auto & smpl = spec->smpl; auto & prompt = spec->prompt; - auto * mem = llama_get_memory(ctx); - int reuse_i = 0; int reuse_n = 0; @@ -169,13 +164,13 @@ llama_tokens common_speculative_gen_draft( } } - LLAMA_LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); - llama_tokens result; + std::vector result; result.reserve(params.n_draft); if (reuse_n == 0) { - llama_memory_clear(mem, false); + llama_kv_cache_clear(ctx, false); prompt.clear(); } else { @@ -194,68 +189,68 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - 
llama_memory_seq_rm (mem, 0, 0, reuse_i); - llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i); } if (reuse_n < (int) prompt.size()) { - llama_memory_seq_rm (mem, 0, reuse_n, -1); + llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); prompt.erase(prompt.begin() + reuse_n, prompt.end()); } } // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); + llama_batch_clear(batch); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { - //LLAMA_LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + //LLAMA_LOG_INFO("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); prompt.push_back(prompt_tgt[i]); } // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { - //LLAMA_LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + //LLAMA_LOG_INFO("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); } const llama_pos n_past = prompt.size(); - LLAMA_LOG_DBG("%s: n_past = %d\n", __func__, n_past); + LLAMA_LOG_INFO("%s: n_past = %d\n", __func__, n_past); - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); + llama_batch_clear(batch); + llama_batch_add (batch, id_last, n_past, { 0 }, true); prompt.push_back(id_last); - //LLAMA_LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + //LLAMA_LOG_INFO("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); llama_decode(ctx, batch); - common_sampler_reset(smpl); + llama_sampling_reset(smpl); // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { - common_batch_clear(batch); + llama_batch_clear(batch); - common_sampler_sample(smpl, ctx, 0, true); + llama_sampling_sample(smpl, ctx, 0, true); - const auto * cur_p = common_sampler_get_candidates(smpl); + const auto * cur_p = llama_sampling_get_candidates(smpl); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LLAMA_LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + LLAMA_LOG_INFO(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str()); } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; - common_sampler_accept(smpl, id, true); + llama_sampling_accept(smpl, ctx, id, true); result.push_back(id); @@ -268,7 +263,7 @@ llama_tokens common_speculative_gen_draft( break; } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + llama_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model llama_decode(ctx, batch); diff --git a/common/speculative.h b/common/speculative.h index 75f2e3112..faa6ee542 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -1,28 +1,29 @@ #pragma once #include "llama.h" -#include "common.h" -struct common_speculative; +#include -struct common_speculative_params { +struct llama_speculative; + +struct 
llama_speculative_params { int n_draft = 16; // max drafted tokens int n_reuse = 256; float p_min = 0.75f; // min probability required to accept a token in the draft }; -struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); +struct llama_speculative * llama_speculative_init(struct llama_context * ctx_dft); -void common_speculative_free(struct common_speculative * spec); +void llama_speculative_free(struct llama_speculative * spec); -bool common_speculative_are_compatible( +bool llama_speculative_are_compatible( const struct llama_context * ctx_tgt, const struct llama_context * ctx_dft); // sample up to n_draft tokens and add them to the batch using the draft model -std::vector common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, +std::vector llama_speculative_gen_draft( + struct llama_speculative * spec, + struct llama_speculative_params params, const std::vector & prompt, llama_token id_last); From 368c4647cf1edf44bdec212c7f83e35893101919 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 04:40:43 +0000 Subject: [PATCH 08/27] remove arg --- common/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index ae326be4d..345311caa 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -170,7 +170,7 @@ std::vector llama_speculative_gen_draft( result.reserve(params.n_draft); if (reuse_n == 0) { - llama_kv_cache_clear(ctx, false); + llama_kv_cache_clear(ctx); prompt.clear(); } else { From 8dbe1d639d7eaf972ad6e620574b22f9350eac7c Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 04:43:17 +0000 Subject: [PATCH 09/27] fix function names --- examples/server/server.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ad934137d..ee2ed3366 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -303,7 +303,7 @@ struct server_slot { int32_t ga_w = 512; // group-attention width // speculative decoding - struct common_speculative * spec = nullptr; + struct llama_speculative * spec = nullptr; llama_context * ctx_dft = nullptr; llama_batch batch_spec = {}; @@ -874,7 +874,7 @@ struct server_context { llama_free(slot.ctx_dft); } if (slot.spec) { - common_speculative_free(slot.spec); + llama_speculative_free(slot.spec); } llama_batch_free(slot.batch_spec); } @@ -924,7 +924,7 @@ struct server_context { return false; } - if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) { + if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) { LOG_ERROR("the draft model is not compatible with the target model", {}); return false; } @@ -998,7 +998,7 @@ struct server_context { return; } - slot.spec = common_speculative_init(slot.ctx_dft); + slot.spec = llama_speculative_init(slot.ctx_dft); if (slot.spec == nullptr) { LOG_ERROR("failed to create speculator", {}); return; @@ -2847,13 +2847,13 @@ struct server_context { llama_token id = slot.sampled; - struct common_speculative_params params_spec; + struct llama_speculative_params params_spec; params_spec.n_draft = n_draft_max; params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; const std::vector & cached_text_tokens = slot.cache_tokens; - std::vector draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + std::vector draft = 
llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { From 1c771020aa74983fba822d749a24220a36fb846c Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 05:17:17 +0000 Subject: [PATCH 10/27] init params_dft to none --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ee2ed3366..401023cf0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -908,7 +908,7 @@ struct server_context { if (!params.model_draft.empty()) { LOG_INFO("loading draft model", {{"model", params.model_draft}}); - gpt_params params_dft = params; + gpt_params params_dft; params_dft.model = params.model_draft; params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_gpu_layers_draft; params_dft.n_gpu_layers = params.n_gpu_layers_draft; From d5924781958475c5ec61d8e64cb8b1d98f6c240d Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 05:26:30 +0000 Subject: [PATCH 11/27] correct value for n_ctx --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 401023cf0..18ae25c7e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -910,7 +910,7 @@ struct server_context { gpt_params params_dft; params_dft.model = params.model_draft; - params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_gpu_layers_draft; + params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx; // TODO: add params_base.speculative.n_ctx params_dft.n_gpu_layers = params.n_gpu_layers_draft; params_dft.n_parallel = 1; params_dft.cache_type_k = params.cache_type_k; From fbd5dfd8660ced64a05a23fe3d5526ded635eb4b Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 07:09:01 +0000 Subject: [PATCH 12/27] prefix kv cache tensors with model name to avoid conflict --- src/llama.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 27647c9d2..c5134bfb5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3581,16 +3581,16 @@ static bool llama_kv_cache_init( //LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_k_l%d", i); + ggml_format_name(kv, "%s.cache_k_l%d", model.name.c_str(), i); cache.k_l.push_back(kv); } else { auto kv_type = cparams.mla_attn == 1 ? 
cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_k_l%d", i); + ggml_format_name(kv, "%s.cache_k_l%d", model.name.c_str(), i); cache.k_l.push_back(kv); if (cparams.mla_attn == 1) { ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_v_l%d", i); + ggml_format_name(kvt, "%s.cache_v_l%d", model.name.c_str(), i); cache.v_l.push_back(kvt); } } @@ -3599,8 +3599,8 @@ static bool llama_kv_cache_init( else { k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); + ggml_format_name(k, "%s.cache_k_l%d", model.name.c_str(), i); + ggml_format_name(v, "%s.cache_v_l%d", model.name.c_str(), i); cache.k_l.push_back(k); cache.v_l.push_back(v); } @@ -7471,7 +7471,7 @@ static bool llm_load_tensors( // output model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { @@ -7480,7 +7480,7 @@ static bool llm_load_tensors( for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); - ggml_context * ctx_split = ctx_for_layer_split(i); + ggml_context * ctx_split = ctx_for_layer_split(i); auto & layer = model.layers[i]; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); @@ -7492,7 +7492,7 @@ static bool llm_load_tensors( if (n_head_kv == 0 && n_head > 0) { // linear attention for DeciLMCausalModel - layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); } else if (n_head_kv > 0) { @@ -7505,8 +7505,8 @@ static bool llm_load_tensors( } // optional bias tensors - - + + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); From d85bc15f9bf7cfa027d7f18ac0c42c993597e575 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 07:15:51 +0000 Subject: [PATCH 13/27] fix call arguments --- common/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 345311caa..fbb9c1c49 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -238,7 +238,7 @@ std::vector llama_speculative_gen_draft( for (int i = 0; i < params.n_draft; ++i) { llama_batch_clear(batch); - llama_sampling_sample(smpl, ctx, 0, true); + llama_sampling_sample(smpl, ctx, nullptr, 0); const auto * cur_p = llama_sampling_get_candidates(smpl); From 388814482a70dc82e728501b81e6ace3e83eb231 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Fri, 25 Jul 2025 07:42:15 +0000 Subject: [PATCH 14/27] fix spec decoding args --- examples/server/server.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 18ae25c7e..2451226b1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -235,7 +235,7 @@ struct slot_params { // speculative decoding parameters struct { - int n_max = 0; // max drafted tokens + int n_max = 16; // max drafted tokens int n_min = 0; // min drafted tokens to accept float p_min = 0.75f; // min probability required to accept a token in the draft } speculative; @@ -1197,9 +1197,9 @@ struct server_context { slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); // speculative decoding parameters - slot.params.speculative.n_max = json_value(data, "speculative.n_max", 0); - slot.params.speculative.n_min = json_value(data, "speculative.n_min", 0); - slot.params.speculative.p_min = json_value(data, "speculative.p_min", 0.75f); + slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min); // Clamp speculative parameters slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); From 1959998697c08fdf8e36a2fd2b7d65f3a5d2c68e Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 08:03:18 +0000 Subject: [PATCH 15/27] correct slot.id --- examples/server/server.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2451226b1..de890bc0b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2870,10 +2870,10 @@ struct server_context { // construct the speculation batch llama_batch_clear(slot.batch_spec); - llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true); + llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true); for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true); + llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); } LOG_VERBOSE("decoding speculative batch", { @@ -2895,7 +2895,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); - llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; From 6bcd795ec3fe2a07a657fdfd112c6d0f67e24462 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Fri, 25 Jul 2025 08:14:25 +0000 Subject: [PATCH 16/27] use n_max --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index de890bc0b..35a695bbd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -990,7 +990,7 @@ struct server_context { // Initialize speculative decoding if a draft model is loaded if (ctx_draft) { - slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1); + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); if (slot.ctx_dft == nullptr) { From 694af02ddd47002e61b2c0059369a608556448f5 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 08:26:29 +0000 Subject: [PATCH 17/27] port the rest of sampling funcs --- common/sampling.cpp | 20 ++++++++++++++++---- common/sampling.h | 3 ++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c5580e88..24fc79cf8 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -513,12 +513,24 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_con return &ctx_sampling->cur_p; } -std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { +std::vector llama_sampling_sampler_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { + std::vector idxs(draft.size() + 1); + for (size_t i = 0; i < idxs.size(); ++i) { + idxs[i] = i; + } + + return llama_sampling_sample_and_accept_n(gsmpl, ctx, idxs, draft); +} + +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + std::vector result; + result.reserve(idxs.size()); size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i); + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); llama_sampling_accept(gsmpl, ctx, id, true); @@ -530,9 +542,9 @@ std::vector llama_sampling_sample_and_accept_n(struct llama_samplin } if (i == draft.size()) { - const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, i); + const llama_token id = llama_sampling_sample(gsmpl, ctx, idxs[i], grammar_first); - llama_sampling_accept(gsmpl, ctx, id, true); + llama_sampling_accept(gsmpl, id, true); result.push_back(id); } diff --git a/common/sampling.h b/common/sampling.h index d209a59f6..dd39556cb 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -183,5 +183,6 @@ void llama_sampling_accept( // access the internal list of current candidate tokens llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling); -std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft); +std::vector llama_sampling_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & draft); +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft); From 4a41cfd0837a3950697c507b3f073be61e881ba6 Mon Sep 17 00:00:00 2001 From: 
"T. M." Date: Fri, 25 Jul 2025 08:27:00 +0000 Subject: [PATCH 18/27] fix func arguments --- common/sampling.cpp | 6 +++--- common/sampling.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 24fc79cf8..526a47ebc 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -513,7 +513,7 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_con return &ctx_sampling->cur_p; } -std::vector llama_sampling_sampler_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; @@ -542,9 +542,9 @@ std::vector llama_sampling_sample_and_accept_n(struct llama_samplin } if (i == draft.size()) { - const llama_token id = llama_sampling_sample(gsmpl, ctx, idxs[i], grammar_first); + const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]); - llama_sampling_accept(gsmpl, id, true); + llama_sampling_accept(gsmpl, ctx, id, true); result.push_back(id); } diff --git a/common/sampling.h b/common/sampling.h index dd39556cb..27401145c 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -183,6 +183,6 @@ void llama_sampling_accept( // access the internal list of current candidate tokens llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling); -std::vector llama_sampling_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & draft); +std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft); std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft); From e938d9f668233bc368385f3f04c8cede5a4008c7 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Fri, 25 Jul 2025 08:52:45 +0000 Subject: [PATCH 19/27] slot.id starts at 1? --- examples/server/server.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 35a695bbd..d72e2d21f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2870,10 +2870,10 @@ struct server_context { // construct the speculation batch llama_batch_clear(slot.batch_spec); - llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true); + llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true); for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true); } LOG_VERBOSE("decoding speculative batch", { @@ -2895,7 +2895,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; From 7f5e298c2ab5cadee3f2f1833d6546a1a58efe8a Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Sun, 27 Jul 2025 05:08:03 +0000 Subject: [PATCH 20/27] Revert "prefix kv cache tensors with model name to avoid conflict" This reverts commit fbd5dfd8660ced64a05a23fe3d5526ded635eb4b. --- src/llama.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c5134bfb5..27647c9d2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3581,16 +3581,16 @@ static bool llama_kv_cache_init( //LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "%s.cache_k_l%d", model.name.c_str(), i); + ggml_format_name(kv, "cache_k_l%d", i); cache.k_l.push_back(kv); } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "%s.cache_k_l%d", model.name.c_str(), i); + ggml_format_name(kv, "cache_k_l%d", i); cache.k_l.push_back(kv); if (cparams.mla_attn == 1) { ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "%s.cache_v_l%d", model.name.c_str(), i); + ggml_format_name(kvt, "cache_v_l%d", i); cache.v_l.push_back(kvt); } } @@ -3599,8 +3599,8 @@ static bool llama_kv_cache_init( else { k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "%s.cache_k_l%d", model.name.c_str(), i); - ggml_format_name(v, "%s.cache_v_l%d", model.name.c_str(), i); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); } @@ -7471,7 +7471,7 @@ static bool llm_load_tensors( // output model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { @@ -7480,7 +7480,7 @@ static bool llm_load_tensors( for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); - ggml_context * ctx_split = ctx_for_layer_split(i); + ggml_context * ctx_split = ctx_for_layer_split(i); auto & layer = model.layers[i]; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); @@ -7492,7 +7492,7 @@ static bool llm_load_tensors( if (n_head_kv == 0 && n_head > 0) { // linear attention for DeciLMCausalModel - layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); } else if (n_head_kv > 0) { @@ -7505,8 +7505,8 @@ static bool llm_load_tensors( } // optional bias tensors - - + + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), 
{n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); From 946fa65c573e1622f599b4a56c1dc28d5a89a4ad Mon Sep 17 00:00:00 2001 From: "T. M." Date: Sun, 3 Aug 2025 01:48:05 +0000 Subject: [PATCH 21/27] disable draft logging --- common/speculative.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index fbb9c1c49..558030a3f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -242,10 +242,10 @@ std::vector llama_speculative_gen_draft( const auto * cur_p = llama_sampling_get_candidates(smpl); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LLAMA_LOG_INFO(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str()); - } + // for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + // LLAMA_LOG_INFO(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + // k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str()); + // } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; From 909a80de5e1156c87910c456501c67b6b1066c1d Mon Sep 17 00:00:00 2001 From: "T. M." Date: Sun, 3 Aug 2025 01:51:32 +0000 Subject: [PATCH 22/27] disable logging in speculative.cpp in mainline, these would be LOG_DEBUG, but since ik_llama doesnt support it, logging is disabled entirely --- common/speculative.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 558030a3f..d70a278ad 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -164,7 +164,7 @@ std::vector llama_speculative_gen_draft( } } - LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + // LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); std::vector result; result.reserve(params.n_draft); @@ -221,7 +221,7 @@ std::vector llama_speculative_gen_draft( const llama_pos n_past = prompt.size(); - LLAMA_LOG_INFO("%s: n_past = %d\n", __func__, n_past); + // LLAMA_LOG_INFO("%s: n_past = %d\n", __func__, n_past); llama_batch_clear(batch); llama_batch_add (batch, id_last, n_past, { 0 }, true); From 8ff1e03cd6abe8f9faf76a181d6bdf1d9a442029 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Thu, 7 Aug 2025 07:21:27 +0000 Subject: [PATCH 23/27] add more draft model parameters --- common/common.cpp | 19 ++++++++++++++++++- common/common.h | 3 +++ examples/server/server.cpp | 6 +++--- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 1801da039..768f9bbc2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -486,6 +486,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_ctx = std::stoi(argv[i]); return true; } + if (arg == "-cd" || arg == "--ctx-size-draft") { + CHECK_ARG + params.n_ctx_draft = std::stoi(argv[i]); + return true; + } if (arg == "--grp-attn-n" || arg == "-gan") { CHECK_ARG params.grp_attn_n = std::stoi(argv[i]); @@ -915,6 +920,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cache_type_v = argv[++i]; return true; } + if (arg == "-ctkd" || arg == "--cache-type-k-draft") { + params.cache_type_k_draft = argv[++i]; + return true; + } + if (arg == "-ctvd" || arg == "--cache-type-v-draft") { + params.cache_type_v_draft = argv[++i]; + return true; + } if (arg == "-mli" || arg == "--multiline-input") { params.multiline_input = true; return true; @@ -1648,6 +1661,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "path to dynamic lookup cache to use for lookup decoding (updated by generation)" }); options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); + options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); @@ -1758,6 +1772,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" }); options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() }); options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() }); + options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" }); + options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" }); options.push_back({ "perplexity" }); options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" }); @@ -2505,7 +2521,8 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); - cparams.n_ctx = params.n_ctx; + // Use draft context size if specified and we have a draft model, otherwise use regular context size + cparams.n_ctx = params.model_draft.empty() ? params.n_ctx : (params.n_ctx_draft > 0 ? 
params.n_ctx_draft : params.n_ctx); cparams.n_seq_max = params.n_parallel; cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; diff --git a/common/common.h b/common/common.h index 99048cd2a..f3daff89f 100644 --- a/common/common.h +++ b/common/common.h @@ -83,6 +83,7 @@ struct gpt_params { int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size + int32_t n_ctx_draft = 0; // context size for draft model int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt @@ -207,6 +208,8 @@ struct gpt_params { std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V + std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model + std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d0ee1731e..c43b719bd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -910,11 +910,11 @@ struct server_context { gpt_params params_dft; params_dft.model = params.model_draft; - params_dft.n_ctx = params.n_gpu_layers_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx; // TODO: add params_base.speculative.n_ctx + params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft; params_dft.n_gpu_layers = params.n_gpu_layers_draft; params_dft.n_parallel = 1; - params_dft.cache_type_k = params.cache_type_k; - params_dft.cache_type_v = params.cache_type_v; + params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft; + params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft; llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); From a694d7d04e1b41106d14be318bcbb948eaa8629a Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Thu, 7 Aug 2025 07:23:12 +0000 Subject: [PATCH 24/27] fix --- common/common.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 768f9bbc2..a8cb3e5c6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -711,7 +711,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } } return true; - } + } if (arg == "--cfg-negative-prompt") { CHECK_ARG sparams.cfg_negative_prompt = argv[i]; @@ -1065,7 +1065,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa size_t pos = 0; while ((pos = servers.find(",")) != std::string::npos) { std::string server = servers.substr(0, pos); - ggml_backend_rpc_buffer_type(server.c_str()); + ggml_backend_rpc_buffer_type(server.c_str()); servers.erase(0, pos + 1); } ggml_backend_rpc_buffer_type(servers.c_str()); @@ -1997,7 +1997,7 @@ std::string string_join(const std::vector & strs, const std::string if (strs.empty()) { return ""; } - + std::ostringstream oss; for (size_t i = 0; i < strs.size(); ++i) { if (i > 0) { @@ -2521,8 +2521,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); - // Use draft context size if specified and we have a draft model, otherwise use regular context size - cparams.n_ctx = params.model_draft.empty() ? params.n_ctx : (params.n_ctx_draft > 0 ? params.n_ctx_draft : params.n_ctx); + cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; cparams.n_batch = params.n_batch; cparams.n_ubatch = params.n_ubatch; From 8effd8c6e555ed42f6c9bd2e351912f1c907b732 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Thu, 7 Aug 2025 15:20:44 +0000 Subject: [PATCH 25/27] pass flash_attn --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c43b719bd..73b01275f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -915,6 +915,7 @@ struct server_context { params_dft.n_parallel = 1; params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft; params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft; + params_dft.flash_attn = params.flash_attn; llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); From ad8d26f28eef6167b97c36cdf959cc002a0ff0c0 Mon Sep 17 00:00:00 2001 From: "T. M." 
Date: Thu, 7 Aug 2025 15:36:37 +0000 Subject: [PATCH 26/27] add speculative params for parity --- common/common.cpp | 17 +++++++++++++++-- common/common.h | 4 +++- examples/server/server.cpp | 4 ++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index a8cb3e5c6..73e7ad379 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -751,11 +751,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_keep = std::stoi(argv[i]); return true; } - if (arg == "--draft") { + if (arg == "--draft" || arg == "--draft-max" || arg == "--draft-n") { CHECK_ARG params.n_draft = std::stoi(argv[i]); return true; } + if (arg == "--draft-min" || arg == "--draft-n-min") { + CHECK_ARG + params.n_draft_min = std::stoi(argv[i]); + return true; + } + if (arg == "--draft-p-min") { + CHECK_ARG + params.p_draft_min = std::stof(argv[i]); + return true; + } if (arg == "--chunks") { CHECK_ARG params.n_chunks = std::stoi(argv[i]); @@ -1653,7 +1663,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); options.push_back({ "speculative", "-tbd, --threads-batch-draft N", "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); - options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", "path to static lookup cache to use for lookup decoding (not updated by generation)" }); @@ -1854,6 +1863,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); + options.push_back({ "*", "--draft-max, --draft, --draft-n N", + "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); + options.push_back({ "*", "--draft-min, --draft-n-min N", "minimum number of draft tokens to use for speculative decoding" }); + options.push_back({ "*", "--draft-p-min P", "minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.p_draft_min }); options.push_back({ "retrieval" }); options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); diff --git a/common/common.h b/common/common.h index f3daff89f..3620d9ee5 100644 --- a/common/common.h +++ b/common/common.h @@ -87,7 +87,9 @@ struct gpt_params { int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding + 
float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy) int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 73b01275f..49a9bc660 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -989,6 +989,10 @@ struct server_context { slot.sparams = params.sparams; + slot.params.speculative.n_max = params.n_draft; + slot.params.speculative.n_min = params.n_draft_min; + slot.params.speculative.p_min = params.p_draft_min; + // Initialize speculative decoding if a draft model is loaded if (ctx_draft) { slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); From ed2a40aab6205c4822e572aaeb3b3c8c59221111 Mon Sep 17 00:00:00 2001 From: "T. M." Date: Thu, 7 Aug 2025 15:41:42 +0000 Subject: [PATCH 27/27] set speculative params in launch_slot_with_task instead --- examples/server/server.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 49a9bc660..ed93e7119 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -989,10 +989,6 @@ struct server_context { slot.sparams = params.sparams; - slot.params.speculative.n_max = params.n_draft; - slot.params.speculative.n_min = params.n_draft_min; - slot.params.speculative.p_min = params.p_draft_min; - // Initialize speculative decoding if a draft model is loaded if (ctx_draft) { slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); @@ -1202,9 +1198,9 @@ struct server_context { slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); // speculative decoding parameters - slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max); - slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min); - slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min); + slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min); // Clamp speculative parameters slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);