11#include " arg.h"
22#include " common.h"
33#include " sampling.h"
4- #include " sampling.cpp"
54#include " log.h"
65#include " llama.h"
76
@@ -60,16 +59,15 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
6059
6160 std::vector<int > sample_ids;
6261
63- smpl-> set_logits ( ctx, num_raw_tokens_header - 1 );
62+ gpt_sampler_sample (smpl, ctx, num_raw_tokens_header - 1 , true );
6463 for (int index = num_raw_tokens_header; index < inp.size (); index++)
6564 {
66- auto &cur_p = smpl->cur_p ; // initialized by set_logits
67- llama_sampler_apply (smpl->chain , &cur_p);
65+ auto cur_p = gpt_sampler_get_candidates (smpl); // initialized by set_logits
6866
6967 int match = -1 ;
70- for (int i = 0 ; i < cur_p. size ; i++)
68+ for (int i = 0 ; i < cur_p-> size ; i++)
7169 {
72- auto tok = cur_p. data [i];
70+ auto tok = cur_p-> data [i];
7371 llama_token candidate = tok.id ;
7472 if (candidate == inp[index])
7573 {
@@ -91,7 +89,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
9189 LOG_ERR (" %s : failed to eval, return code %d\n " , __func__, 1 );
9290 exit (1 );
9391 }
94- smpl-> set_logits ( ctx, 0 );
92+ gpt_sampler_sample (smpl, ctx, 0 , true );
9593 }
9694
9795 // bit pack sample_ids
@@ -247,7 +245,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
247245 exit (1 );
248246 }
249247
250- smpl-> set_logits ( ctx, num_raw_tokens_header - 1 );
248+ gpt_sampler_sample (smpl, ctx, num_raw_tokens_header - 1 , true );
251249
252250 int index = 0 ;
253251 int bit_index = (1 + num_raw_tokens_header * 4 ) * 8 ;
@@ -268,10 +266,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
268266 sample_id |= (int )sample_ids_bitpacked[i + (bit_index / 8 )];
269267 }
270268
271- auto &cur_p = smpl->cur_p ; // initialized by set_logits
272- llama_sampler_apply (smpl->chain , &cur_p);
269+ auto cur_p = gpt_sampler_get_candidates (smpl); // initialized by set_logits
273270
274- auto token_id = cur_p. data [sample_id].id ;
271+ auto token_id = cur_p-> data [sample_id].id ;
275272
276273 out.push_back (token_id);
277274
@@ -303,7 +300,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
303300 LOG_ERR (" %s : failed to eval, return code %d\n " , __func__, 1 );
304301 exit (1 );
305302 }
306- smpl->set_logits (ctx, 0 );
303+ gpt_sampler_sample (smpl, ctx, 0 , true );
304+
307305 index++;
308306
309307 bit_index += 8 * (fixed_token_cost + bytesize);
@@ -328,10 +326,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
328326 {
329327 int sample_id = id;
330328
331- auto &cur_p = smpl->cur_p ; // initialized by set_logits
332- llama_sampler_apply (smpl->chain , &cur_p);
329+ auto cur_p = gpt_sampler_get_candidates (smpl); // initialized by set_logits
333330
334- auto token_id = cur_p. data [sample_id].id ;
331+ auto token_id = cur_p-> data [sample_id].id ;
335332 out.push_back (token_id);
336333 if (!inp.size () || token_id == inp[num_raw_tokens_header + index])
337334 {
@@ -350,7 +347,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
350347 LOG_ERR (" %s : failed to eval, return code %d\n " , __func__, 1 );
351348 exit (1 );
352349 }
353- smpl-> set_logits ( ctx, 0 );
350+ gpt_sampler_sample (smpl, ctx, 0 , true );
354351 }
355352 index++;
356353
0 commit comments