 #include "speculative.h"
 
+#include "ggml.h"
+#include "llama.h"
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
 
 #include <cstring>
 #include <algorithm>
+#include <map>
 
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct common_speculative {
-    struct llama_context * ctx;
+    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+    struct llama_context * ctx_dft;
     struct common_sampler * smpl;
 
     llama_batch batch;
-    llama_tokens prompt;
+    llama_tokens prompt_dft;
+    bool vocab_dft_compatible = true; // whether retokenization is needed
+    std::map<std::string, std::string> tgt_dft_replacements = {};
 };
 
 struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_tgt,
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .ctx    = */ ctx_dft,
-        /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt = */ {},
+        /* .ctx_tgt    = */ ctx_tgt,
+        /* .ctx_dft    = */ ctx_dft,
+        /* .smpl       = */ nullptr,
+        /* .batch      = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt_dft = */ {},
+        /* .vocab_dft_compatible = */ false,
     };
 
     // TODO: optimize or pass from outside?
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
     }
 #endif
 
+    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
+
     return result;
 }
 
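Note for integrators: `common_speculative_init` now takes the target context as well and probes vocab compatibility up front. A minimal sketch of the updated call site, assuming `ctx_tgt` and `ctx_dft` were created elsewhere:

    // both contexts must outlive the speculative state
    struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
    // ... drafting loop ...
    common_speculative_free(spec);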
@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
 }
 
 bool common_speculative_are_compatible(
-        const struct llama_context * ctx_tgt,
-        const struct llama_context * ctx_dft) {
+    const struct llama_context * ctx_tgt,
+    const struct llama_context * ctx_dft) {
     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
     const struct llama_model * model_dft = llama_get_model(ctx_dft);
 
@@ -90,40 +102,41 @@ bool common_speculative_are_compatible(
     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
 
     if (vocab_type_tgt != vocab_type_dft) {
-        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
-                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
         return false;
     }
 
-    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
         llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
         llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
-        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
-        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
         return false;
     }
 
     {
         const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
         const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
-        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
 
         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
-                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
             return false;
         }
 
         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
             const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
             const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
-                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                         common_token_to_piece(ctx_tgt, i).c_str(),
                         common_token_to_piece(ctx_dft, i).c_str());
                 return false;
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
     return true;
 }
 
+void common_speculative_add_replacement_tgt_dft(
+        struct common_speculative * spec,
+        const char *source, const char *dest) {
+    spec->tgt_dft_replacements[source] = dest;
+}
+
+static std::string replace_to_dft(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.first);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.first.length(), pair.second);
+            pos = result.find(pair.first, pos + pair.second.length());
+        }
+    }
+    return result;
+}
+
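With the checks above downgraded from `LOG_ERR` to `LOG_DBG`, an incompatible pair is no longer a hard error: init records the result and `common_speculative_gen_draft` falls back to retokenization. Callers that want to detect the fallback themselves can still probe, since the helper stays public; a small sketch:

    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        LOG_INF("draft vocab differs from target; drafts will be retokenized\n");
    }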
+static std::string replace_to_tgt(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto& pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.second);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.second.length(), pair.first);
+            pos = result.find(pair.second, pos + pair.first.length());
+        }
+    }
+    return result;
+}
+
+
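These two helpers rewrite a detokenized string in opposite directions, resuming the search just past each replacement so a `dest` that contains `source` cannot loop forever. Registering a mapping might look like the following; the token spellings are purely illustrative, not taken from any real model pair:

    // hypothetical: the target emits ChatML markers, the draft spells them differently
    common_speculative_add_replacement_tgt_dft(spec, "<|im_start|>", "<start_of_turn>");
    common_speculative_add_replacement_tgt_dft(spec, "<|im_end|>",   "<end_of_turn>");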
 llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
         struct common_speculative_params params,
-        const llama_tokens & prompt_tgt,
+        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
         llama_token id_last) {
     auto & batch  = spec->batch;
-    auto & ctx    = spec->ctx;
+    auto & ctx_tgt = spec->ctx_tgt;
+    auto & ctx_dft = spec->ctx_dft;
     auto & smpl   = spec->smpl;
-    auto & prompt = spec->prompt;
+    auto & prompt_dft = spec->prompt_dft;
 
-    auto * mem = llama_get_memory(ctx);
+    auto * mem_dft = llama_get_memory(ctx_dft);
 
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
+
+    llama_tokens prompt_tgt_draft_model;
+    if (!spec->vocab_dft_compatible) {
+        std::string text;
+        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
+        text = replace_to_dft(spec, text);
+        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
+
+        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+        const auto * model_tgt = llama_get_model(ctx_tgt);
+        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+        text.resize(-n_chars);
+        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+        text = replace_to_dft(spec, text);
+
+        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+        id_last = common_tokenize(ctx_dft, text, false, true)[0];
+    }
+    // prompt_tgt's tokens will always be compatible with ctx_dft
+    const llama_tokens &prompt_tgt =
+        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-    for (int i = 0; i < (int) prompt.size(); ++i) {
+    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
-               i       + cur < (int) prompt.size() &&
-               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+               i       + cur < (int) prompt_dft.size() &&
+               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
             cur++;
         }
 
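The `id_last` conversion above relies on `llama_detokenize`'s sizing convention: called with a null buffer it returns the negative of the byte count required, which is why the `GGML_ASSERT(n_chars < 0 && ...)` expects a negative value. The same two-pass pattern in isolation (`vocab` and `tok` are placeholders):

    // pass 1: null buffer, returns -(bytes required)
    int32_t n = llama_detokenize(vocab, &tok, 1, nullptr, 0, false, false);
    std::string piece(-n, '\0');
    // pass 2: fill the exactly-sized buffer
    llama_detokenize(vocab, &tok, 1, piece.data(), piece.size(), false, false);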
@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
         }
     }
 
-    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
 
     llama_tokens result;
     result.reserve(params.n_draft);
 
     if (reuse_n == 0) {
-        llama_memory_clear(mem, false);
-
-        prompt.clear();
+        llama_memory_clear(mem_dft, false);
+        prompt_dft.clear();
     } else {
         // this happens when a previous draft has been discarded (for example, due to being too small), but the
         // target model agreed with it. in this case, we simply pass back the previous results to save compute
-        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
-            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
-                result.push_back(prompt[i]);
+        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+                result.push_back(prompt_dft[i]);
 
                 if (params.n_draft <= (int) result.size()) {
                     break;
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
         }
 
         if (reuse_i > 0) {
-            llama_memory_seq_rm (mem, 0, 0, reuse_i);
-            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+            llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
+            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
 
-            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
         }
 
-        if (reuse_n < (int) prompt.size()) {
-            llama_memory_seq_rm (mem, 0, reuse_n, -1);
-
-            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        if (reuse_n < (int) prompt_dft.size()) {
+            llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
+            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
         }
     }
 
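For reference, the pair above trims the draft KV cache to the reused window and slides it into place: `llama_memory_seq_rm(mem_dft, 0, 0, reuse_i)` drops the cells of sequence 0 in `[0, reuse_i)`, and `llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i)` shifts the positions of the remaining cells down by `reuse_i` (a negative upper bound means "to the end"), keeping the cache aligned with the trimmed `prompt_dft`.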
@@ -214,42 +286,42 @@ llama_tokens common_speculative_gen_draft(
         // LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
-        prompt.push_back(prompt_tgt[i]);
+        prompt_dft.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
         // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
     }
 
-    const llama_pos n_past = prompt.size();
+    const llama_pos n_past = prompt_dft.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
     common_batch_clear(batch);
     common_batch_add  (batch, id_last, n_past, { 0 }, true);
 
-    prompt.push_back(id_last);
+    prompt_dft.push_back(id_last);
 
-    // LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode(ctx_dft, batch);
 
     common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0, true);
 
         const auto * cur_p = common_sampler_get_candidates(smpl);
 
         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
         }
 
         // add drafted token for each sequence
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
         common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
 
-        prompt.push_back(id);
+        prompt_dft.push_back(id);
     }
 
+    if (!spec->vocab_dft_compatible) {
+        std::string detokenized = common_detokenize(ctx_dft, result, true);
+        detokenized = replace_to_tgt(spec, detokenized);
+        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+        result = common_tokenize(ctx_tgt, detokenized, false, true);
+        if (result.size() > (size_t)params.n_draft) {
+            result.resize(params.n_draft);
+        }
+    }
     return result;
 }
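Taken together, `common_speculative_gen_draft` now always returns tokens in the target vocab: when the vocabs differ, the draft is detokenized, mapped through `replace_to_tgt`, and retokenized, then clamped back to `n_draft`, since retokenization can change the token count. A sketch of the per-step call from the caller's side; the verification loop is elided and the `n_draft` value is illustrative:

    common_speculative_params params;
    params.n_draft = 16;
    llama_tokens draft = common_speculative_gen_draft(spec, params, prompt_tgt, id_last);
    // verify `draft` against the target model; accepted tokens extend prompt_tgt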