Skip to content

Commit da444fa

Browse files
committed
compress: remove sampling.cpp dependency
1 parent bec8398 commit da444fa

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

examples/compress/compress.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "arg.h"
22
#include "common.h"
33
#include "sampling.h"
4-
#include "sampling.cpp"
54
#include "log.h"
65
#include "llama.h"
76

@@ -60,16 +59,15 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
6059

6160
std::vector<int> sample_ids;
6261

63-
smpl->set_logits(ctx, num_raw_tokens_header - 1);
62+
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
6463
for (int index = num_raw_tokens_header; index < inp.size(); index++)
6564
{
66-
auto &cur_p = smpl->cur_p; // initialized by set_logits
67-
llama_sampler_apply(smpl->chain, &cur_p);
65+
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
6866

6967
int match = -1;
70-
for (int i = 0; i < cur_p.size; i++)
68+
for (int i = 0; i < cur_p->size; i++)
7169
{
72-
auto tok = cur_p.data[i];
70+
auto tok = cur_p->data[i];
7371
llama_token candidate = tok.id;
7472
if (candidate == inp[index])
7573
{
@@ -91,7 +89,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
9189
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
9290
exit(1);
9391
}
94-
smpl->set_logits(ctx, 0);
92+
gpt_sampler_sample(smpl, ctx, 0, true);
9593
}
9694

9795
// bit pack sample_ids
@@ -247,7 +245,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
247245
exit(1);
248246
}
249247

250-
smpl->set_logits(ctx, num_raw_tokens_header - 1);
248+
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
251249

252250
int index = 0;
253251
int bit_index = (1 + num_raw_tokens_header * 4) * 8;
@@ -268,10 +266,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
268266
sample_id |= (int)sample_ids_bitpacked[i + (bit_index / 8)];
269267
}
270268

271-
auto &cur_p = smpl->cur_p; // initialized by set_logits
272-
llama_sampler_apply(smpl->chain, &cur_p);
269+
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
273270

274-
auto token_id = cur_p.data[sample_id].id;
271+
auto token_id = cur_p->data[sample_id].id;
275272

276273
out.push_back(token_id);
277274

@@ -303,7 +300,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
303300
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
304301
exit(1);
305302
}
306-
smpl->set_logits(ctx, 0);
303+
gpt_sampler_sample(smpl, ctx, 0, true);
304+
307305
index++;
308306

309307
bit_index += 8 * (fixed_token_cost + bytesize);
@@ -328,10 +326,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
328326
{
329327
int sample_id = id;
330328

331-
auto &cur_p = smpl->cur_p; // initialized by set_logits
332-
llama_sampler_apply(smpl->chain, &cur_p);
329+
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
333330

334-
auto token_id = cur_p.data[sample_id].id;
331+
auto token_id = cur_p->data[sample_id].id;
335332
out.push_back(token_id);
336333
if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
337334
{
@@ -350,7 +347,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
350347
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
351348
exit(1);
352349
}
353-
smpl->set_logits(ctx, 0);
350+
gpt_sampler_sample(smpl, ctx, 0, true);
354351
}
355352
index++;
356353

0 commit comments

Comments
 (0)