@@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
320320 return cur_p.data [cur_p.selected ].id ;
321321}
322322
323- std::vector<llama_token> common_sampler_sample_n (struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int > & idxs, const llama_tokens & draft, bool grammar_first) {
323+ std::vector<llama_token> common_sampler_sample_and_accept_n (struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int > & idxs, const llama_tokens & draft, bool grammar_first) {
324324 GGML_ASSERT (idxs.size () == draft.size () + 1 && " idxs.size() must be draft.size() + 1" );
325325
326326 std::vector<llama_token> result;
@@ -330,25 +330,33 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
330330 for (; i < draft.size (); i++) {
331331 const llama_token id = common_sampler_sample (gsmpl, ctx, idxs[i], grammar_first);
332332
333+ common_sampler_accept (gsmpl, id, true );
334+
335+ result.push_back (id);
336+
333337 if (draft[i] != id) {
334338 break ;
335339 }
340+ }
341+
342+ if (i == draft.size ()) {
343+ const llama_token id = common_sampler_sample (gsmpl, ctx, idxs[i], grammar_first);
344+
345+ common_sampler_accept (gsmpl, id, true );
336346
337347 result.push_back (id);
338348 }
339349
340- result.push_back (common_sampler_sample (gsmpl, ctx, idxs[i], grammar_first));
341-
342350 return result;
343351}
344352
345- std::vector<llama_token> common_sampler_sample_n (struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
353+ std::vector<llama_token> common_sampler_sample_and_accept_n (struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
346354 std::vector<int > idxs (draft.size () + 1 );
347355 for (size_t i = 0 ; i < idxs.size (); ++i) {
348356 idxs[i] = i;
349357 }
350358
351- return common_sampler_sample_n (gsmpl, ctx, idxs, draft, grammar_first);
359+ return common_sampler_sample_and_accept_n (gsmpl, ctx, idxs, draft, grammar_first);
352360}
353361
354362uint32_t common_sampler_get_seed (const struct common_sampler * gsmpl) {
0 commit comments