77
88#include < cstring>
99#include < algorithm>
10+ #include < map>
1011
1112#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
1213#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -19,6 +20,7 @@ struct common_speculative {
1920 llama_batch batch;
2021 llama_tokens prompt_dft;
2122 bool vocab_dft_compatible = true ; // whether retokenization is needed
23+ std::map<std::string, std::string> tgt_dft_replacements = {};
2224};
2325
2426struct common_speculative * common_speculative_init (
@@ -144,6 +146,41 @@ bool common_speculative_are_compatible(
144146 return true ;
145147}
146148
149+ void common_speculative_add_replacement_tgt_dft (
150+ struct common_speculative * spec,
151+ const char *source, const char *dest) {
152+ spec->tgt_dft_replacements [source] = dest;
153+ }
154+
155+ static std::string replace_to_dft (
156+ struct common_speculative * spec,
157+ const std::string& input) {
158+ std::string result = input;
159+ for (const auto & pair : spec->tgt_dft_replacements ) {
160+ size_t pos = result.find (pair.first );
161+ while (pos != std::string::npos) {
162+ result.replace (pos, pair.first .length (), pair.second );
163+ pos = result.find (pair.first , pos + pair.second .length ());
164+ }
165+ }
166+ return result;
167+ }
168+
169+ static std::string replace_to_tgt (
170+ struct common_speculative * spec,
171+ const std::string& input) {
172+ std::string result = input;
173+ for (const auto & pair : spec->tgt_dft_replacements ) {
174+ size_t pos = result.find (pair.second );
175+ while (pos != std::string::npos) {
176+ result.replace (pos, pair.second .length (), pair.first );
177+ pos = result.find (pair.second , pos + pair.first .length ());
178+ }
179+ }
180+ return result;
181+ }
182+
183+
147184llama_tokens common_speculative_gen_draft (
148185 struct common_speculative * spec,
149186 struct common_speculative_params params,
@@ -168,10 +205,11 @@ llama_tokens common_speculative_gen_draft(
168205
169206 std::string text;
170207 text = common_detokenize (ctx_tgt, prompt_tgt_main_model, false );
208+ text = replace_to_dft (spec, text);
171209 LOG_DBG (" main->draft detokenized string: '%s'\n " , text.c_str ());
172210 prompt_tgt_draft_model = common_tokenize (ctx_dft, text, false , false );
173-
174211 text.clear ();
212+
175213 const llama_vocab * vocab_tgt = llama_model_get_vocab (model_tgt);
176214 int32_t n_chars;
177215 n_chars = llama_detokenize (vocab_tgt, &id_last, 1 , &text[0 ], text.size (), false , false );
@@ -180,6 +218,7 @@ llama_tokens common_speculative_gen_draft(
180218 n_chars = llama_detokenize (vocab_tgt, &id_last, 1 , &text[0 ], text.size (), false , false );
181219 }
182220 text.resize (n_chars);
221+ text = replace_to_dft (spec, text);
183222 LOG_DBG (" main->draft detokenized id_last(%d): '%s'\n " , id_last, text.c_str ());
184223 id_last = common_tokenize (ctx_dft, text, false , false )[0 ];
185224 }
@@ -312,6 +351,7 @@ llama_tokens common_speculative_gen_draft(
312351
313352 if (!spec->vocab_dft_compatible ) {
314353 std::string detokenized = common_detokenize (ctx_dft, result, false );
354+ detokenized = replace_to_tgt (spec, detokenized);
315355 LOG_DBG (" draft->main detokenized string: '%s'\n " , detokenized.c_str ());
316356 result = common_tokenize (ctx_tgt, detokenized, false , false );
317357 }
0 commit comments