
Commit d3fb498

implement speculative decoding translation in examples/speculative.cpp (wip)
1 parent 829b762 commit d3fb498

1 file changed: examples/speculative/speculative.cpp (80 additions, 10 deletions)
@@ -1,3 +1,4 @@
+#include "speculative.h"
 #include "arg.h"
 #include "common.h"
 #include "sampling.h"
@@ -102,6 +103,35 @@ int main(int argc, char ** argv) {
     auto * mem_tgt = llama_get_memory(ctx_tgt);
     auto * mem_dft = llama_get_memory(ctx_dft);
 
+    // Check if vocabularies are compatible
+    bool vocab_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+
+    // Check vocabulary size difference
+    if (vocab_compatible) {
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+        const int vocab_diff = abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            vocab_compatible = false;
+            LOG_DBG("vocab size difference too large: %d vs %d\n", n_vocab_tgt, n_vocab_dft);
+        } else {
+            // Check token consistency for a range of tokens
+            for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+                if (strcmp(llama_vocab_get_text(vocab_tgt, i), llama_vocab_get_text(vocab_dft, i)) != 0) {
+                    vocab_compatible = false;
+                    LOG_DBG("token %d differs between models\n", i);
+                    break;
+                }
+            }
+        }
+    }
+
+    if (!vocab_compatible) {
+        LOG_INF("The draft model '%s' is not compatible with the target model '%s'. Tokens will be translated between the draft and target models.\n",
+                params.speculative.model.path.c_str(), params.model.path.c_str());
+    }
+
     // Tokenize the prompt
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
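Note: common_speculative_are_compatible(), the SPEC_VOCAB_MAX_SIZE_DIFFERENCE and SPEC_VOCAB_CHECK_START_TOKEN_ID constants, and the vocab_tgt/vocab_dft handles are all defined outside this diff. A minimal sketch of the same check as a standalone helper, with assumed values of 100 and 5 standing in for the two constants:

    #include <algorithm> // std::min
    #include <cstdlib>   // abs
    #include <cstring>   // strcmp

    // Sketch only: mirrors the check in the hunk above with assumed thresholds.
    static bool vocabs_look_compatible(const llama_vocab * vocab_tgt, const llama_vocab * vocab_dft) {
        const int max_size_difference  = 100; // assumed SPEC_VOCAB_MAX_SIZE_DIFFERENCE
        const int check_start_token_id = 5;   // assumed SPEC_VOCAB_CHECK_START_TOKEN_ID

        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);

        if (abs(n_vocab_tgt - n_vocab_dft) > max_size_difference) {
            return false; // vocab sizes diverge too much
        }

        // the lowest token ids are typically special/control tokens, so the
        // text comparison starts past them
        for (int i = check_start_token_id; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            if (strcmp(llama_vocab_get_text(vocab_tgt, i), llama_vocab_get_text(vocab_dft, i)) != 0) {
                return false; // token text differs between the models
            }
        }

        return true;
    }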
@@ -127,7 +157,16 @@ int main(int argc, char ** argv) {
     // eval the prompt with both models
     llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
     llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
+
+    // Handle prompt tokens for draft model
+    if (vocab_compatible) {
+        llama_decode(ctx_dft, llama_batch_get_one(inp.data(), n_input));
+    } else {
+        // Convert prompt tokens from target to draft model
+        std::string prompt_text = common_detokenize(ctx_tgt, inp, true);
+        std::vector<llama_token> inp_dft = common_tokenize(ctx_dft, prompt_text, true, true);
+        llama_decode(ctx_dft, llama_batch_get_one(inp_dft.data(), inp_dft.size()));
+    }
 
     const auto t_enc_end = ggml_time_us();

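The prompt path above round-trips the whole token sequence through text rather than copying ids, since the two models no longer share a vocabulary. A hypothetical helper (not part of this commit) capturing that round trip, using only calls already present in the diff and the surrounding file's includes:

    // Sketch only: re-tokenize a token sequence from one model's vocabulary
    // into another's by detokenizing to text and tokenizing again. The result
    // may have a different length than the input, which is why the draft
    // context above is fed inp_dft.size() tokens rather than n_input.
    static std::vector<llama_token> translate_tokens(
            llama_context * ctx_src,
            llama_context * ctx_dst,
            const std::vector<llama_token> & tokens) {
        const std::string text = common_detokenize(ctx_src, tokens, true);
        return common_tokenize(ctx_dst, text, true, true);
    }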
@@ -224,19 +263,37 @@ int main(int argc, char ** argv) {
 
             LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
             float r = u_dist(rng);
-            llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
             //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+            llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
             // acquire the token probabilities assigned by the draft and target models
+            llama_token token_tgt = drafts[s].tokens[i_dft];
+
+            // If vocabularies are not compatible, we need to convert the token
+            llama_token token_dft = token_tgt;
+            if (!vocab_compatible) {
+                // Convert from target token to draft token by detokenizing and retokenizing
+                std::string token_text = common_token_to_piece(ctx_tgt, token_tgt);
+                std::vector<llama_token> tokens_dft = common_tokenize(ctx_dft, token_text, false, true);
+                if (!tokens_dft.empty()) {
+                    token_dft = tokens_dft[0];
+                } else {
+                    // If conversion fails, skip this token
+                    drafts[s].active = false;
+                    active_seqs.erase(s);
+                    continue;
+                }
+            }
+
             for (size_t i = 0; i < dist_tgt.size; i++) {
-                if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
+                if (dist_tgt.data[i].id == token_tgt) {
                     p_tgt = dist_tgt.data[i].p;
                     break;
                 }
             }
             for (size_t i = 0; i < dist_dft.size; i++) {
-                if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
+                if (dist_dft.data[i].id == token_dft) {
                     p_dft = dist_dft.data[i].p;
                     break;
                 }
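The detokenize/re-tokenize step above recurs per token, and appears again in the drafting loop further down. A hypothetical helper (not part of this commit) that isolates it:

    // Sketch only: map a single token id between vocabularies by rendering it
    // to text with the source context and re-tokenizing with the destination
    // context. Returns LLAMA_TOKEN_NULL when the piece does not survive the
    // round trip; the caller above treats that as a failed translation and
    // drops the sequence.
    static llama_token translate_token(llama_context * ctx_src, llama_context * ctx_dst, llama_token id) {
        const std::string piece = common_token_to_piece(ctx_src, id);
        const std::vector<llama_token> ids = common_tokenize(ctx_dst, piece, false, true);
        return ids.empty() ? LLAMA_TOKEN_NULL : ids[0];
    }

Keeping only the first re-tokenized id is lossy whenever one piece maps to several tokens in the other vocabulary; the "(wip)" in the commit message suggests this approximation is a known limitation.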
@@ -501,25 +558,37 @@ int main(int argc, char ** argv) {
 
                 // add drafted token for each sequence
                 for (int is = 0; is < (int) sa.size(); ++is) {
-                    const llama_token id = cur_p->data[is].id;
-
+                    const llama_token id_dft = cur_p->data[is].id;
                     const int s = sa[is];
 
-                    common_sampler_accept(drafts[s].smpl, id, true);
+                    common_sampler_accept(drafts[s].smpl, id_dft, true);
 
-                    drafts[s].tokens.push_back(id);
+                    // Convert draft token to target token if vocabularies are not compatible
+                    llama_token id_tgt = id_dft;
+                    if (!vocab_compatible) {
+                        std::string token_text = common_token_to_piece(ctx_dft, id_dft);
+                        std::vector<llama_token> tokens_tgt = common_tokenize(ctx_tgt, token_text, false, true);
+                        if (!tokens_tgt.empty()) {
+                            id_tgt = tokens_tgt[0];
+                        } else {
+                            // If conversion fails, skip this token
+                            continue;
+                        }
+                    }
+
+                    drafts[s].tokens.push_back(id_dft);
                     // save cur_p.data into drafts[s].dists
                     drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
 
                     // add unique drafted tokens to the target batch
                     drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-                    common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+                    common_batch_add(batch_tgt, id_tgt, n_past_tgt + i + 1, { s }, true);
 
                     // add the token to the batch for batched decoding with the draft model
                     drafts[s].i_batch_dft = batch_dft.n_tokens;
 
-                    common_batch_add(batch_dft, id, n_past_cur, { s }, true);
+                    common_batch_add(batch_dft, id_dft, n_past_cur, { s }, true);
 
                     if (batch_tgt.n_tokens > n_draft) {
                         drafts[s].drafting = false;
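With a helper like translate_token above, the conversion block in this hunk reduces to a single call plus the empty-result check. Note the bookkeeping: drafts[s].tokens and batch_dft keep the draft-vocabulary id (id_dft), while batch_tgt receives the translated target-vocabulary id (id_tgt), so each context is only ever decoded with ids from its own vocabulary.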
@@ -588,6 +657,7 @@ int main(int argc, char ** argv) {
     LOG_INF("target:\n\n");
     common_perf_print(ctx_tgt, smpl);
 
+
     common_sampler_free(smpl);
     for (int s = 0; s < n_seq_dft; ++s) {
         common_sampler_free(drafts[s].smpl);
