+#include "speculative.h"
 #include "arg.h"
 #include "common.h"
 #include "sampling.h"
@@ -102,6 +103,35 @@ int main(int argc, char ** argv) {
     auto * mem_tgt = llama_get_memory(ctx_tgt);
     auto * mem_dft = llama_get_memory(ctx_dft);
 
+    // Check if vocabularies are compatible
+    bool vocab_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+
+    // Check vocabulary size difference
+    if (vocab_compatible) {
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+        const int vocab_diff = abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            vocab_compatible = false;
+            LOG_DBG("vocab size difference too large: %d vs %d\n", n_vocab_tgt, n_vocab_dft);
+        } else {
+            // Check token consistency for a range of tokens
+            for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+                if (strcmp(llama_vocab_get_text(vocab_tgt, i), llama_vocab_get_text(vocab_dft, i)) != 0) {
+                    vocab_compatible = false;
+                    LOG_DBG("token %d differs between models\n", i);
+                    break;
+                }
+            }
+        }
+    }
+
+    if (!vocab_compatible) {
+        LOG_INF("The draft model '%s' is not compatible with the target model '%s'. Tokens will be translated between the draft and target models.\n",
+                params.speculative.model.path.c_str(), params.model.path.c_str());
+    }
+
     // Tokenize the prompt
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
@@ -127,7 +157,16 @@ int main(int argc, char ** argv) {
     // eval the prompt with both models
     llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
     llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
+
+    // Handle prompt tokens for draft model
+    if (vocab_compatible) {
+        llama_decode(ctx_dft, llama_batch_get_one(inp.data(), n_input));
+    } else {
+        // Convert prompt tokens from target to draft model
+        std::string prompt_text = common_detokenize(ctx_tgt, inp, true);
+        std::vector<llama_token> inp_dft = common_tokenize(ctx_dft, prompt_text, true, true);
+        llama_decode(ctx_dft, llama_batch_get_one(inp_dft.data(), inp_dft.size()));
+    }
 
     const auto t_enc_end = ggml_time_us();
 
@@ -224,19 +263,37 @@ int main(int argc, char ** argv) {
 
                         LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                         float r = u_dist(rng);
-                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
                         // GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
                         // acquire the token probabilities assigned by the draft and target models
+                        llama_token token_tgt = drafts[s].tokens[i_dft];
+
+                        // If vocabularies are not compatible, we need to convert the token
+                        llama_token token_dft = token_tgt;
+                        if (!vocab_compatible) {
+                            // Convert from target token to draft token by detokenizing and retokenizing
+                            std::string token_text = common_token_to_piece(ctx_tgt, token_tgt);
+                            std::vector<llama_token> tokens_dft = common_tokenize(ctx_dft, token_text, false, true);
+                            if (!tokens_dft.empty()) {
+                                token_dft = tokens_dft[0];
+                            } else {
+                                // If conversion fails, skip this token
+                                drafts[s].active = false;
+                                active_seqs.erase(s);
+                                continue;
+                            }
+                        }
+
                         for (size_t i = 0; i < dist_tgt.size; i++) {
-                            if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
+                            if (dist_tgt.data[i].id == token_tgt) {
                                 p_tgt = dist_tgt.data[i].p;
                                 break;
                             }
                         }
                         for (size_t i = 0; i < dist_dft.size; i++) {
-                            if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
+                            if (dist_dft.data[i].id == token_dft) {
                                 p_dft = dist_dft.data[i].p;
                                 break;
                             }
@@ -501,25 +558,37 @@ int main(int argc, char ** argv) {
 
                 // add drafted token for each sequence
                 for (int is = 0; is < (int) sa.size(); ++is) {
-                    const llama_token id = cur_p->data[is].id;
-
+                    const llama_token id_dft = cur_p->data[is].id;
                     const int s = sa[is];
 
-                    common_sampler_accept(drafts[s].smpl, id, true);
+                    common_sampler_accept(drafts[s].smpl, id_dft, true);
 
-                    drafts[s].tokens.push_back(id);
+                    // Convert draft token to target token if vocabularies are not compatible
+                    llama_token id_tgt = id_dft;
+                    if (!vocab_compatible) {
+                        std::string token_text = common_token_to_piece(ctx_dft, id_dft);
+                        std::vector<llama_token> tokens_tgt = common_tokenize(ctx_tgt, token_text, false, true);
+                        if (!tokens_tgt.empty()) {
+                            id_tgt = tokens_tgt[0];
+                        } else {
+                            // If conversion fails, skip this token
+                            continue;
+                        }
+                    }
+
+                    drafts[s].tokens.push_back(id_dft);
                     // save cur_p.data into drafts[s].dists
                     drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
 
                     // add unique drafted tokens to the target batch
                     drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-                    common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+                    common_batch_add(batch_tgt, id_tgt, n_past_tgt + i + 1, { s }, true);
 
                     // add the token to the batch for batched decoding with the draft model
                     drafts[s].i_batch_dft = batch_dft.n_tokens;
 
-                    common_batch_add(batch_dft, id, n_past_cur, { s }, true);
+                    common_batch_add(batch_dft, id_dft, n_past_cur, { s }, true);
 
                     if (batch_tgt.n_tokens > n_draft) {
                         drafts[s].drafting = false;
@@ -588,6 +657,7 @@ int main(int argc, char ** argv) {
     LOG_INF("target:\n\n");
     common_perf_print(ctx_tgt, smpl);
 
+
     common_sampler_free(smpl);
     for (int s = 0; s < n_seq_dft; ++s) {
         common_sampler_free(drafts[s].smpl);
0 commit comments