#include "speculative.h"

+#include "ggml.h"
+#include "llama.h"
#include "log.h"
#include "common.h"
#include "sampling.h"

#include <cstring>
#include <algorithm>
+#include <map>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

struct common_speculative {
-    struct llama_context * ctx;
+    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
+    struct llama_context * ctx_dft;
    struct common_sampler * smpl;

    llama_batch batch;
-    llama_tokens prompt;
+    llama_tokens prompt_dft;
+    bool vocab_dft_compatible = true; // true if the draft vocab matches the target vocab and no retokenization is needed
+    std::map<std::string, std::string> tgt_dft_replacements = {};
};

struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft) {
    auto * result = new common_speculative {
-        /* .ctx    = */ ctx_dft,
-        /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt = */ {},
+        /* .ctx_tgt              = */ ctx_tgt,
+        /* .ctx_dft              = */ ctx_dft,
+        /* .smpl                 = */ nullptr,
+        /* .batch                = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt_dft           = */ {},
+        /* .vocab_dft_compatible = */ false,
    };

    // TODO: optimize or pass from outside?
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
    }
#endif

+    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
+    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
+
    return result;
}

@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
}

bool common_speculative_are_compatible(
-        const struct llama_context * ctx_tgt,
-        const struct llama_context * ctx_dft) {
+    const struct llama_context * ctx_tgt,
+    const struct llama_context * ctx_dft) {
    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
    const struct llama_model * model_dft = llama_get_model(ctx_dft);

@@ -90,40 +102,41 @@ bool common_speculative_are_compatible(
    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

    if (vocab_type_tgt != vocab_type_dft) {
-        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
-                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
        return false;
    }

-    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
-        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
-        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
        return false;
    }

    {
        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
-        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
-                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return false;
        }

        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
-                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
                        common_token_to_piece(ctx_tgt, i).c_str(),
                        common_token_to_piece(ctx_dft, i).c_str());
                return false;
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
    return true;
}

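+// register a target->draft token text replacement; replace_to_dft()/replace_to_tgt() apply it
+// whenever prompts or drafts have to be retokenized because the two vocabs are not compatible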
+void common_speculative_add_replacement_tgt_dft(
+        struct common_speculative * spec,
+        const char *source, const char *dest) {
+    spec->tgt_dft_replacements[source] = dest;
+}
+
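+// apply the registered replacements to text moving from the target vocab to the draft vocab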
+static std::string replace_to_dft(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.first);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.first.length(), pair.second);
+            pos = result.find(pair.first, pos + pair.second.length());
+        }
+    }
+    return result;
+}
+
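+// apply the registered replacements in the reverse direction (draft vocab back to target vocab)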
+static std::string replace_to_tgt(
+        struct common_speculative * spec,
+        const std::string& input) {
+    std::string result = input;
+    for (const auto & pair : spec->tgt_dft_replacements) {
+        size_t pos = result.find(pair.second);
+        while (pos != std::string::npos) {
+            result.replace(pos, pair.second.length(), pair.first);
+            pos = result.find(pair.second, pos + pair.first.length());
+        }
+    }
+    return result;
+}
+
+
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
-        const llama_tokens & prompt_tgt,
+        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
        llama_token id_last) {
    auto & batch = spec->batch;
-    auto & ctx = spec->ctx;
+    auto & ctx_tgt = spec->ctx_tgt;
+    auto & ctx_dft = spec->ctx_dft;
    auto & smpl = spec->smpl;
-    auto & prompt = spec->prompt;
+    auto & prompt_dft = spec->prompt_dft;

-    auto * mem = llama_get_memory(ctx);
+    auto * mem_dft = llama_get_memory(ctx_dft);

    int reuse_i = 0;
    int reuse_n = 0;

-    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
+
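+    // if the draft vocab is not compatible, detokenize the target-vocab prompt and retokenize it with the draft vocab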
+    llama_tokens prompt_tgt_draft_model;
+    if (!spec->vocab_dft_compatible) {
+        std::string text;
+        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
+        text = replace_to_dft(spec, text);
+        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
+        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
+
+        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
+        const auto * model_tgt = llama_get_model(ctx_tgt);
+        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
+
+        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
+        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
+        text.resize(-n_chars);
+        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
+        text = replace_to_dft(spec, text);
+
+        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
+        id_last = common_tokenize(ctx_dft, text, false, true)[0];
+    }
+    // prompt_tgt's tokens will always be compatible with ctx_dft
+    const llama_tokens & prompt_tgt =
+        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;

    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

    // reuse as much as possible from the old draft context
    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
-    for (int i = 0; i < (int) prompt.size(); ++i) {
+    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_tgt.size() &&
-               i + cur < (int) prompt.size() &&
-               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+               i + cur < (int) prompt_dft.size() &&
+               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
            cur++;
        }

@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
        }
    }

-    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

    llama_tokens result;
    result.reserve(params.n_draft);

    if (reuse_n == 0) {
-        llama_memory_clear(mem, false);
-
-        prompt.clear();
+        llama_memory_clear(mem_dft, false);
+        prompt_dft.clear();
    } else {
        // this happens when a previous draft has been discarded (for example, due to being too small), but the
        // target model agreed with it. in this case, we simply pass back the previous results to save compute
-        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
-            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
-                result.push_back(prompt[i]);
+        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
+                result.push_back(prompt_dft[i]);

                if (params.n_draft <= (int) result.size()) {
                    break;
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
        }

        if (reuse_i > 0) {
-            llama_memory_seq_rm (mem, 0, 0, reuse_i);
-            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+            llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
+            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

-            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
        }

-        if (reuse_n < (int) prompt.size()) {
-            llama_memory_seq_rm(mem, 0, reuse_n, -1);
-
-            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        if (reuse_n < (int) prompt_dft.size()) {
+            llama_memory_seq_rm(mem_dft, 0, reuse_n, -1);
+            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
        }
    }

@@ -214,42 +286,42 @@ llama_tokens common_speculative_gen_draft(
        // LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

-        prompt.push_back(prompt_tgt[i]);
+        prompt_dft.push_back(prompt_tgt[i]);
    }

    // we should rarely end-up here during normal decoding
    if (batch.n_tokens > 0) {
        // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);
    }

-    const llama_pos n_past = prompt.size();
+    const llama_pos n_past = prompt_dft.size();

    LOG_DBG("%s: n_past = %d\n", __func__, n_past);

    common_batch_clear(batch);
    common_batch_add (batch, id_last, n_past, { 0 }, true);

-    prompt.push_back(id_last);
+    prompt_dft.push_back(id_last);

-    // LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

-    llama_decode(ctx, batch);
+    llama_decode(ctx_dft, batch);

    common_sampler_reset(smpl);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0, true);

        const auto * cur_p = common_sampler_get_candidates(smpl);

        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode(ctx_dft, batch);

-        prompt.push_back(id);
+        prompt_dft.push_back(id);
    }

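+    // the draft tokens were sampled with the draft vocab; convert them back to the target vocab before returning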
+    if (!spec->vocab_dft_compatible) {
+        std::string detokenized = common_detokenize(ctx_dft, result, true);
+        detokenized = replace_to_tgt(spec, detokenized);
+        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
+        result = common_tokenize(ctx_tgt, detokenized, false, true);
+        if (result.size() > (size_t) params.n_draft) {
+            result.resize(params.n_draft);
+        }
+    }
    return result;
}
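
For reference, a minimal usage sketch of the changed API. The two contexts, the target-vocab prompt, and id_last are assumed to already exist, and the replacement strings and n_draft value below are illustrative only, not part of this commit:

    // initialize with both contexts so the helper can retokenize when the vocabs differ
    struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);

    // optional: map target-vocab marker strings to their draft-vocab equivalents (example strings only)
    common_speculative_add_replacement_tgt_dft(spec, "<|im_start|>", "<start_of_turn>");

    common_speculative_params params;
    params.n_draft = 16;

    // prompt_tgt and id_last are expressed in the target vocab; the returned draft is too
    llama_tokens draft = common_speculative_gen_draft(spec, params, prompt_tgt, id_last);

    common_speculative_free(spec);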