Skip to content

Commit b9fdf20

Browse files
committed
add --spec-replace flag
1 parent 8419931 commit b9fdf20

File tree

6 files changed

+59
-1
lines changed

6 files changed

+59
-1
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3217,6 +3217,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
32173217
params.speculative.model.path = value;
32183218
}
32193219
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
3220+
add_opt(common_arg(
3221+
{"--spec-replace"}, "TARGET", "DRAFT",
3222+
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
3223+
[](common_params & params, const std::string & tgt, const std::string & dft) {
3224+
params.speculative.replacements.push_back({ tgt, dft });
3225+
}
3226+
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
32203227
add_opt(common_arg(
32213228
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
32223229
string_format(

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ struct common_params_speculative {
198198
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
199199
float p_split = 0.1f; // speculative decoding split probability
200200
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
201+
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
201202

202203
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
203204
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V

common/speculative.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <cstring>
99
#include <algorithm>
10+
#include <map>
1011

1112
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
1213
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -19,6 +20,7 @@ struct common_speculative {
1920
llama_batch batch;
2021
llama_tokens prompt_dft;
2122
bool vocab_dft_compatible = true; // whether retokenization is needed
23+
std::map<std::string, std::string> tgt_dft_replacements = {};
2224
};
2325

2426
struct common_speculative * common_speculative_init(
@@ -144,6 +146,41 @@ bool common_speculative_are_compatible(
144146
return true;
145147
}
146148

149+
void common_speculative_add_replacement_tgt_dft(
150+
struct common_speculative * spec,
151+
const char *source, const char *dest) {
152+
spec->tgt_dft_replacements[source] = dest;
153+
}
154+
155+
static std::string replace_to_dft(
156+
struct common_speculative * spec,
157+
const std::string& input) {
158+
std::string result = input;
159+
for (const auto& pair : spec->tgt_dft_replacements) {
160+
size_t pos = result.find(pair.first);
161+
while (pos != std::string::npos) {
162+
result.replace(pos, pair.first.length(), pair.second);
163+
pos = result.find(pair.first, pos + pair.second.length());
164+
}
165+
}
166+
return result;
167+
}
168+
169+
static std::string replace_to_tgt(
170+
struct common_speculative * spec,
171+
const std::string& input) {
172+
std::string result = input;
173+
for (const auto& pair : spec->tgt_dft_replacements) {
174+
size_t pos = result.find(pair.second);
175+
while (pos != std::string::npos) {
176+
result.replace(pos, pair.second.length(), pair.first);
177+
pos = result.find(pair.second, pos + pair.first.length());
178+
}
179+
}
180+
return result;
181+
}
182+
183+
147184
llama_tokens common_speculative_gen_draft(
148185
struct common_speculative * spec,
149186
struct common_speculative_params params,
@@ -168,10 +205,11 @@ llama_tokens common_speculative_gen_draft(
168205

169206
std::string text;
170207
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, false);
208+
text = replace_to_dft(spec, text);
171209
LOG_DBG("main->draft detokenized string: '%s'\n", text.c_str());
172210
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, false);
173-
174211
text.clear();
212+
175213
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
176214
int32_t n_chars;
177215
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, false);
@@ -180,6 +218,7 @@ llama_tokens common_speculative_gen_draft(
180218
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, false);
181219
}
182220
text.resize(n_chars);
221+
text = replace_to_dft(spec, text);
183222
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
184223
id_last = common_tokenize(ctx_dft, text, false, false)[0];
185224
}
@@ -312,6 +351,7 @@ llama_tokens common_speculative_gen_draft(
312351

313352
if (!spec->vocab_dft_compatible) {
314353
std::string detokenized = common_detokenize(ctx_dft, result, false);
354+
detokenized = replace_to_tgt(spec, detokenized);
315355
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
316356
result = common_tokenize(ctx_tgt, detokenized, false, false);
317357
}

common/speculative.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ bool common_speculative_are_compatible(
2323
const struct llama_context * ctx_tgt,
2424
const struct llama_context * ctx_dft);
2525

26+
void common_speculative_add_replacement_tgt_dft(
27+
struct common_speculative * spec,
28+
const char *source, const char *dest);
29+
2630
// sample up to n_draft tokens and add them to the batch using the draft model
2731
llama_tokens common_speculative_gen_draft(
2832
struct common_speculative * spec,

examples/speculative-simple/speculative-simple.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ int main(int argc, char ** argv) {
127127
params_spec.p_min = p_min;
128128

129129
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
130+
for (auto &pair : params.speculative.replacements) {
131+
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
132+
}
130133

131134
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
132135

tools/server/server.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,6 +2079,9 @@ struct server_context {
20792079
SRV_ERR("%s", "failed to create speculator\n");
20802080
return;
20812081
}
2082+
for (auto &pair : params_base.speculative.replacements) {
2083+
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
2084+
}
20822085
}
20832086

20842087
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

0 commit comments

Comments
 (0)