Skip to content

Commit 160769d

Browse files
committed
accept special tokens when translating between draft/main models
1 parent b9fdf20 commit 160769d

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

common/speculative.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,23 +204,23 @@ llama_tokens common_speculative_gen_draft(
204204
const llama_model * model_tgt = llama_get_model(ctx_tgt);
205205

206206
std::string text;
207-
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, false);
207+
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
208208
text = replace_to_dft(spec, text);
209209
LOG_DBG("main->draft detokenized string: '%s'\n", text.c_str());
210-
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, false);
210+
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
211211
text.clear();
212212

213213
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
214214
int32_t n_chars;
215-
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, false);
215+
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, true);
216216
if (n_chars < 0) {
217217
text.resize(-n_chars);
218-
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, false);
218+
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, true);
219219
}
220220
text.resize(n_chars);
221221
text = replace_to_dft(spec, text);
222222
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
223-
id_last = common_tokenize(ctx_dft, text, false, false)[0];
223+
id_last = common_tokenize(ctx_dft, text, false, true)[0];
224224
}
225225
// prompt_tgt's tokens will always be compatible with ctx_dft
226226
const llama_tokens &prompt_tgt =
@@ -350,10 +350,10 @@ llama_tokens common_speculative_gen_draft(
350350
}
351351

352352
if (!spec->vocab_dft_compatible) {
353-
std::string detokenized = common_detokenize(ctx_dft, result, false);
353+
std::string detokenized = common_detokenize(ctx_dft, result, true);
354354
detokenized = replace_to_tgt(spec, detokenized);
355355
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
356-
result = common_tokenize(ctx_tgt, detokenized, false, false);
356+
result = common_tokenize(ctx_tgt, detokenized, false, true);
357357
}
358358
return result;
359359
}

0 commit comments

Comments
 (0)