@@ -99,6 +99,49 @@ int main(int argc, char ** argv) {
     const bool vocab_type_dft = llama_vocab_type(vocab_dft);
     LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
 
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
+    {
+        // the vocab sizes are allowed to differ by a small amount
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return 1;
+        }
+
+        // verify that the token texts match over the shared vocab range,
+        // skipping the first few IDs where control tokens may legitimately differ
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
+            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return 1;
+            }
+        }
+    }
+
     auto * mem_tgt = llama_get_memory(ctx_tgt);
     auto * mem_dft = llama_get_memory(ctx_dft);
 
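The two thresholds referenced above, SPEC_VOCAB_MAX_SIZE_DIFFERENCE and SPEC_VOCAB_CHECK_START_TOKEN_ID, are defined near the top of the file and are not part of this hunk. A minimal sketch of plausible definitions, with the exact values being an assumption here rather than something established by the diff:

    // assumed values, not shown in this hunk: tolerate a small difference in
    // vocab size, and skip the lowest token IDs in the per-token text check,
    // since low-ID control/special tokens can legitimately differ between models
    #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
    #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

Together, the three checks only let speculation proceed when the draft model's token IDs map to the same token texts as the target's over the shared vocab range, so draft-proposed tokens remain meaningful to the target model.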