Skip to content

Commit d9fb3b2

Browse files
committed
speculative : fix the draft sampling
ggml-ci
1 parent be5f611 commit d9fb3b2

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

common/sampling.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
320320
return cur_p.data[cur_p.selected].id;
321321
}
322322

323-
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
323+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
324324
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
325325

326326
std::vector<llama_token> result;
@@ -330,25 +330,33 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
330330
for (; i < draft.size(); i++) {
331331
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
332332

333+
common_sampler_accept(gsmpl, id, true);
334+
335+
result.push_back(id);
336+
333337
if (draft[i] != id) {
334338
break;
335339
}
340+
}
341+
342+
if (i == draft.size()) {
343+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
344+
345+
common_sampler_accept(gsmpl, id, true);
336346

337347
result.push_back(id);
338348
}
339349

340-
result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
341-
342350
return result;
343351
}
344352

345-
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
353+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
346354
std::vector<int> idxs(draft.size() + 1);
347355
for (size_t i = 0; i < idxs.size(); ++i) {
348356
idxs[i] = i;
349357
}
350358

351-
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
359+
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
352360
}
353361

354362
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

common/sampling.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,24 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
6262

6363
// generalized version of common_sampler_sample
6464
//
65-
// will cross-reference the sampled tokens with a batch of draft tokens
66-
// if the sampler disagrees at some point, we stop and return the sampled tokens up to now
65+
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
66+
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
6767
//
68-
// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)`
68+
common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
69+
//
70+
// is equivalent to
71+
//
72+
// common_sampler_sample(gsmpl, ctx, idx);
73+
// common_sampler_accept(gsmpl, token, true);
6974
//
7075
// requires: idxs.size() == draft.size() + 1
7176
//
7277
// returns at least 1 token, up to idxs.size()
7378
//
74-
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
79+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
7580

7681
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
77-
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
82+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
7883

7984
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
8085

examples/speculative-simple/speculative-simple.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ int main(int argc, char ** argv) {
163163
// available logits from the batch and sample the next token until we run out of logits or the sampler
164164
// disagrees with the draft
165165
//
166-
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
166+
const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
167+
168+
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
167169

168170
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
169171

0 commit comments

Comments (0)