Merged
30 commits
de5ecab
server : integrate speculative decoding
g2mt Jul 25, 2025
98f6a48
server: Fix field names
g2mt Jul 25, 2025
80a0579
server: fix include, whitespace
g2mt Jul 25, 2025
5c96a7f
fix compile errors in speculative.cpp
g2mt Jul 25, 2025
99c1ef3
add llama_sampling_sample_and_accept_n to sampling
g2mt Jul 25, 2025
642b70a
finish porting speculative decoding in server
g2mt Jul 25, 2025
422af9e
port functions from common/speculative, common/sampling
g2mt Jul 25, 2025
368c464
remove arg
g2mt Jul 25, 2025
8dbe1d6
fix function names
g2mt Jul 25, 2025
1c77102
init params_dft to none
g2mt Jul 25, 2025
d592478
correct value for n_ctx
g2mt Jul 25, 2025
fbd5dfd
prefix kv cache tensors with model name to avoid conflict
g2mt Jul 25, 2025
d85bc15
fix call arguments
g2mt Jul 25, 2025
3888144
fix spec decoding args
g2mt Jul 25, 2025
1959998
correct slot.id
g2mt Jul 25, 2025
6bcd795
use n_max
g2mt Jul 25, 2025
694af02
port the rest of sampling funcs
g2mt Jul 25, 2025
4a41cfd
fix func arguments
g2mt Jul 25, 2025
e938d9f
slot.id starts at 1?
g2mt Jul 25, 2025
07e7cb3
Merge branch 'main' into speculative-port
g2mt Jul 27, 2025
7f5e298
Revert "prefix kv cache tensors with model name to avoid conflict"
g2mt Jul 27, 2025
c81e7b1
Merge remote-tracking branch 'fork/speculative-port' into speculative…
g2mt Jul 27, 2025
946fa65
disable draft logging
g2mt Aug 3, 2025
909a80d
disable logging in speculative.cpp
g2mt Aug 3, 2025
8ff1e03
add more draft model parameters
g2mt Aug 7, 2025
a694d7d
fix
g2mt Aug 7, 2025
8effd8c
pass flash_attn
g2mt Aug 7, 2025
ad8d26f
add speculative params for parity
g2mt Aug 7, 2025
ed2a40a
set speculative params in launch_slot_with_task instead
g2mt Aug 7, 2025
a2810b4
Merge branch 'main' into speculative-port
g2mt Aug 15, 2025
1 change: 1 addition & 0 deletions common/CMakeLists.txt
@@ -74,6 +74,7 @@ add_library(${TARGET} STATIC
train.cpp
ngram-cache.h
ngram-cache.cpp
speculative.cpp
)

if (BUILD_SHARED_LIBS)
22 changes: 19 additions & 3 deletions common/common.cpp
@@ -486,6 +486,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.n_ctx = std::stoi(argv[i]);
return true;
}
if (arg == "-cd" || arg == "--ctx-size-draft") {
CHECK_ARG
params.n_ctx_draft = std::stoi(argv[i]);
return true;
}
if (arg == "--grp-attn-n" || arg == "-gan") {
CHECK_ARG
params.grp_attn_n = std::stoi(argv[i]);
@@ -706,7 +711,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
}
return true;
}
}
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
@@ -915,6 +920,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cache_type_v = argv[++i];
return true;
}
if (arg == "-ctkd" || arg == "--cache-type-k-draft") {
params.cache_type_k_draft = argv[++i];
return true;
}
if (arg == "-ctvd" || arg == "--cache-type-v-draft") {
params.cache_type_v_draft = argv[++i];
return true;
}
if (arg == "-mli" || arg == "--multiline-input") {
params.multiline_input = true;
return true;
@@ -1052,7 +1065,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
size_t pos = 0;
while ((pos = servers.find(",")) != std::string::npos) {
std::string server = servers.substr(0, pos);
ggml_backend_rpc_buffer_type(server.c_str());
ggml_backend_rpc_buffer_type(server.c_str());
servers.erase(0, pos + 1);
}
ggml_backend_rpc_buffer_type(servers.c_str());
@@ -1648,6 +1661,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"path to dynamic lookup cache to use for lookup decoding (updated by generation)" });

options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
options.push_back({ "*", "-cd, --ctx-size-draft N", "size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.n_ctx_draft });
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch });
options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
@@ -1758,6 +1772,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" });
options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
options.push_back({ "*", "-ctkd, --cache-type-k-draft TYPE", "KV cache data type for K for the draft model" });
options.push_back({ "*", "-ctvd, --cache-type-v-draft TYPE", "KV cache data type for V for the draft model" });

options.push_back({ "perplexity" });
options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" });
@@ -1981,7 +1997,7 @@ std::string string_join(const std::vector<std::string> & strs, const std::string
if (strs.empty()) {
return "";
}

std::ostringstream oss;
for (size_t i = 0; i < strs.size(); ++i) {
if (i > 0) {
3 changes: 3 additions & 0 deletions common/common.h
@@ -83,6 +83,7 @@ struct gpt_params {
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 0; // context size
int32_t n_ctx_draft = 0; // context size for draft model
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
@@ -207,6 +208,8 @@ struct gpt_params {

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V
std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model

// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
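The draft-specific fields above default to 0 and empty strings, which suggests the server falls back to the main-model settings when they are not given. Below is a minimal sketch of that fallback; the helper name draft_ctx_params_from and its exact placement in the server are assumptions, not code from this PR.

#include <string>

#include "common.h" // gpt_params
#include "llama.h"  // llama_context_params

// Sketch only: derive the draft model's context parameters from gpt_params,
// reusing the target model's values when the draft-specific ones are unset.
static llama_context_params draft_ctx_params_from(const gpt_params & params) {
    llama_context_params cparams = llama_context_default_params();

    // n_ctx_draft == 0 means "not set" -> reuse the target model's context size
    cparams.n_ctx = params.n_ctx_draft > 0 ? params.n_ctx_draft : params.n_ctx;

    // empty cache type means "not set" -> reuse the target model's K/V cache types;
    // the string-to-ggml_type conversion is handled elsewhere in common.cpp
    const std::string ctk = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
    const std::string ctv = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
    (void) ctk;
    (void) ctv;

    return cparams;
}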
48 changes: 47 additions & 1 deletion common/sampling.cpp
@@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };
ctx_sampling->cur_p = { cur.data(), cur.size(), false };

llama_token_data_array & cur_p = ctx_sampling->cur_p;

// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
@@ -506,3 +508,47 @@ void llama_sampling_accept(
llama_sampler_dry_accept(ctx_sampling->smpl, id);
}
}

llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) {
return &ctx_sampling->cur_p;
}

std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}

return llama_sampling_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}

std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

std::vector<llama_token> result;
result.reserve(idxs.size());

size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);

llama_sampling_accept(gsmpl, ctx, id, true);

result.push_back(id);

if (draft[i] != id) {
break;
}
}

if (i == draft.size()) {
const llama_token id = llama_sampling_sample(gsmpl, ctx, nullptr, idxs[i]);

llama_sampling_accept(gsmpl, ctx, id, true);

result.push_back(id);
}

return result;
}
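For reference, this helper is what a speculative loop uses to verify a draft: the draft tokens are assumed to have already been decoded into the target context at consecutive batch positions, so index i holds the logits checked against draft[i], and sampling stops at the first mismatch. A rough usage sketch under those assumptions follows; the wrapper function is illustrative, not the PR's server code.

#include <vector>

#include "llama.h"
#include "sampling.h"

// Sketch only: verify a drafted sequence against the target model.
// Assumes the draft tokens were already decoded into ctx so that batch
// index i holds the logits used to check draft[i].
static std::vector<llama_token> verify_draft(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx,
        const std::vector<llama_token> & draft) {
    // returns the accepted prefix of the draft plus one token sampled after
    // the first divergence (or after the full draft if everything matched)
    std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(ctx_sampling, ctx, draft);

    // ids.size() - 1 of the draft tokens were accepted by the target model
    return ids;
}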

10 changes: 10 additions & 0 deletions common/sampling.h
@@ -101,6 +101,8 @@ struct llama_sampling_context {

size_t n_valid; // Number of correct top tokens with correct probabilities.

llama_token_data_array cur_p; // current candidates

std::mt19937 rng;
};

@@ -176,3 +178,11 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);

// access the internal list of current candidate tokens
llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling);

// sample and accept tokens along the draft; returns at least 1 token, up to draft.size() + 1
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);

std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
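llama_sampling_get_candidates exposes the cur_p array that llama_sampling_prepare_impl now stores on the context, which makes it possible to read back token probabilities after a sampling call (for example, to report per-token probabilities from the server). A small sketch, assuming a softmax has been applied during sampling so the p fields are populated:

#include "llama.h"
#include "sampling.h"

// Sketch only: look up the probability the last sampling step assigned to a token.
// Meaningful only after llama_sampling_sample() has run and applied a softmax.
static float token_probability(struct llama_sampling_context * ctx_sampling, llama_token id) {
    const llama_token_data_array * cur_p = llama_sampling_get_candidates(ctx_sampling);

    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == id) {
            return cur_p->data[i].p; // probability from the last softmax
        }
    }

    return 0.0f; // token not among the current candidates
}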