#include "speculative.h"

+#include "ggml.h"
+#include "llama.h"
#include "log.h"
#include "common.h"
#include "sampling.h"

#include <cstring>
#include <algorithm>
+#include <map>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

struct common_speculative {
-    struct llama_context * ctx;
+    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+    struct llama_context * ctx_dft;
    struct common_sampler * smpl;

    llama_batch batch;
-    llama_tokens prompt;
+    llama_tokens prompt_dft;
+    bool vocab_dft_compatible = true; // whether retokenization is needed
+    std::map<std::string, std::string> tgt_dft_replacements = {};
};

struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft) {
    auto * result = new common_speculative {
-        /* .ctx                  = */ ctx_dft,
-        /* .smpl                 = */ nullptr,
-        /* .batch                = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt               = */ {},
+        /* .ctx_tgt              = */ ctx_tgt,
+        /* .ctx_dft              = */ ctx_dft,
+        /* .smpl                 = */ nullptr,
+        /* .batch                = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt_dft           = */ {},
+        /* .vocab_dft_compatible = */ false,
    };

    // TODO: optimize or pass from outside?
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
    }
#endif

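+    // record whether the draft vocab matches the target vocab; if not, gen_draft retokenizes between the two models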
+    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
+
    return result;
}

@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
}

bool common_speculative_are_compatible(
-    const struct llama_context * ctx_tgt,
-    const struct llama_context * ctx_dft) {
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
    const struct llama_model * model_dft = llama_get_model(ctx_dft);

@@ -90,40 +102,41 @@ bool common_speculative_are_compatible(
    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

    if (vocab_type_tgt != vocab_type_dft) {
-        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
-                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
        return false;
    }

-    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
-        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
-        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
        return false;
    }

    {
        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
-        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
-                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return false;
        }

        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
-                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                        common_token_to_piece(ctx_tgt, i).c_str(),
                        common_token_to_piece(ctx_dft, i).c_str());
                return false;
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
    return true;
}

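+// register a string replacement to apply when translating text between the target and draft vocabs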
+void common_speculative_add_replacement_tgt_dft(
+        struct common_speculative * spec,
+        const char *source, const char *dest) {
+    spec->tgt_dft_replacements[source] = dest;
+}
+
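+// apply the registered replacements in the target -> draft direction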
+static std::string replace_to_dft(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.first);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.first.length(), pair.second);
+            pos = result.find(pair.first, pos + pair.second.length());
+        }
+    }
+    return result;
+}
+
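+// apply the registered replacements in the opposite (draft -> target) direction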
+static std::string replace_to_tgt(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.second);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.second.length(), pair.first);
+            pos = result.find(pair.second, pos + pair.first.length());
+        }
+    }
+    return result;
+}
+
+
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
-        const llama_tokens & prompt_tgt,
+        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
        llama_token id_last) {
    auto & batch = spec->batch;
-    auto & ctx = spec->ctx;
+    auto & ctx_tgt = spec->ctx_tgt;
+    auto & ctx_dft = spec->ctx_dft;
    auto & smpl = spec->smpl;
-    auto & prompt = spec->prompt;
+    auto & prompt_dft = spec->prompt_dft;

-    auto * mem = llama_get_memory(ctx);
+    auto * mem_dft = llama_get_memory(ctx_dft);

    int reuse_i = 0;
    int reuse_n = 0;

-    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
+
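+    // the incoming prompt is expressed in the target vocab; if the draft vocab differs,
+    // translate it by detokenizing with the target model and retokenizing with the draft model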
+    llama_tokens prompt_tgt_draft_model;
+    if (!spec->vocab_dft_compatible) {
+        std::string text;
+        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
+        text = replace_to_dft(spec, text);
+        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
+
+        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+        const auto * model_tgt = llama_get_model(ctx_tgt);
+        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+        text.resize(-n_chars);
+        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+        text = replace_to_dft(spec, text);
+
+        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+        id_last = common_tokenize(ctx_dft, text, false, true)[0];
+    }
+    // prompt_tgt's tokens will always be compatible with ctx_dft
+    const llama_tokens &prompt_tgt =
+        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;

    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

    // reuse as much as possible from the old draft context
    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-    for (int i = 0; i < (int) prompt.size(); ++i) {
+    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_tgt.size() &&
-               i + cur < (int) prompt.size() &&
-               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+               i + cur < (int) prompt_dft.size() &&
+               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
            cur++;
        }

@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
        }
    }

-    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

    llama_tokens result;
    result.reserve(params.n_draft);

    if (reuse_n == 0) {
-        llama_memory_clear(mem, false);
-
-        prompt.clear();
+        llama_memory_clear(mem_dft, false);
+        prompt_dft.clear();
    } else {
        // this happens when a previous draft has been discarded (for example, due to being too small), but the
        // target model agreed with it. in this case, we simply pass back the previous results to save compute
-        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
-            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
-                result.push_back(prompt[i]);
+        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+                result.push_back(prompt_dft[i]);

                if (params.n_draft <= (int) result.size()) {
                    break;
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
        }

        if (reuse_i > 0) {
-            llama_memory_seq_rm(mem, 0, 0, reuse_i);
-            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+            llama_memory_seq_rm(mem_dft, 0, 0, reuse_i);
+            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

-            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
        }

-        if (reuse_n < (int) prompt.size()) {
-            llama_memory_seq_rm(mem, 0, reuse_n, -1);
-
-            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        if (reuse_n < (int) prompt_dft.size()) {
+            llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
+            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
        }
    }

@@ -214,42 +286,42 @@ llama_tokens common_speculative_gen_draft(
        // LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

-        prompt.push_back(prompt_tgt[i]);
+        prompt_dft.push_back(prompt_tgt[i]);
    }

    // we should rarely end-up here during normal decoding
    if (batch.n_tokens > 0) {
        // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
    }

-    const llama_pos n_past = prompt.size();
+    const llama_pos n_past = prompt_dft.size();

    LOG_DBG("%s: n_past = %d\n", __func__, n_past);

    common_batch_clear(batch);
    common_batch_add(batch, id_last, n_past, { 0 }, true);

-    prompt.push_back(id_last);
+    prompt_dft.push_back(id_last);

-    // LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

-    llama_decode(ctx, batch);
+    llama_decode(ctx_dft, batch);

    common_sampler_reset(smpl);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0, true);

        const auto * cur_p = common_sampler_get_candidates(smpl);

        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);

-        prompt.push_back(id);
+        prompt_dft.push_back(id);
    }

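+    // the drafted tokens are in the draft vocab; translate them back to the target vocab before returning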
+    if (!spec->vocab_dft_compatible) {
+        std::string detokenized = common_detokenize(ctx_dft, result, true);
+        detokenized = replace_to_tgt(spec, detokenized);
+        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+        result = common_tokenize(ctx_tgt, detokenized, false, true);
+        if (result.size() > (size_t)params.n_draft) {
+            result.resize(params.n_draft);
+        }
+    }
    return result;
}