Commit 3fb701d

implement tokenizer

1 parent 7d2b818 commit 3fb701d

File tree

1 file changed: +107 −4 lines changed

src/llama-sampling.cpp

Lines changed: 107 additions & 4 deletions
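The change wires an llguidance tokenizer into the llg sampler: struct llama_sampler_llg gains an LlgTokenizer handle, constraints are initialized against that tokenizer instead of nullptr, and a new helper, llama_sampler_llg_new_tokenizer, builds the tokenizer table from the model vocabulary, with a one-entry static cache keyed on the model pointer.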
@@ -2350,14 +2350,16 @@ struct llama_sampler_llg {
     const struct llama_model * model;
     std::string grammar_kind;
     std::string grammar_data;
+    LlgTokenizer *tokenizer;
     LlgConstraint *grammar;
     LlgMaskResult llg_res;
     bool has_llg_res;
 };
 
-static LlgConstraint *llama_sampler_llg_new(const char * grammar_kind, const char * grammar_data) {
+static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
+        const char * grammar_kind, const char * grammar_data) {
     LlgConstraintInit cinit;
-    llg_constraint_init_set_defaults(&cinit, nullptr);
+    llg_constraint_init_set_defaults(&cinit, tokenizer);
     auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
     if (llg_get_error(c)) {
         LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c));
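With the new signature, a tokenizer and the constraint built on top of it form a pair that is created, cloned, and freed together. A minimal sketch of that pairing, using only functions visible in this diff and assuming the cut-off error branch above returns nullptr (try_constraint, the "lark" kind, and the grammar text are placeholders, not part of the commit):

    // Sketch: build a constraint against an explicit tokenizer, then tear both down.
    static bool try_constraint(const struct llama_model * model) {
        LlgTokenizer * tokenizer = llama_sampler_llg_new_tokenizer(model); // added later in this diff
        if (tokenizer == nullptr) {
            return false; // tokenizer construction failed and was already logged
        }
        // placeholder grammar kind/data; llama_sampler_llg_new logs llg_get_error() on failure
        LlgConstraint * grammar = llama_sampler_llg_new(tokenizer, "lark", "start: /[0-9]+/");
        bool ok = grammar != nullptr;
        if (ok) {
            llg_free_constraint(grammar);
        }
        llg_free_tokenizer(tokenizer);
        return ok;
    }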
@@ -2418,7 +2420,7 @@ static void llama_sampler_llg_reset(struct llama_sampler * smpl) {
         return;
     }
 
-    auto * grammar_new = llama_sampler_llg_new(ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
+    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
     llg_free_constraint(ctx->grammar);
     ctx->grammar = grammar_new;
     ctx->has_llg_res = false;
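Note that reset rebuilds only the constraint: the stored grammar strings are recompiled through llama_sampler_llg_new, while ctx->tokenizer is reused as-is, so the comparatively expensive vocabulary table built by llama_sampler_llg_new_tokenizer is not reconstructed on every reset.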
@@ -2437,6 +2439,7 @@ static struct llama_sampler * llama_sampler_llg_clone(const struct llama_sampler
         result_ctx->grammar_kind = ctx->grammar_kind;
         result_ctx->grammar_data = ctx->grammar_data;
         result_ctx->grammar = llg_clone_constraint(ctx->grammar);
+        result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
     }
 }
 
@@ -2448,6 +2451,7 @@ static void llama_sampler_llg_free(struct llama_sampler * smpl) {
 
     if (ctx->grammar) {
         llg_free_constraint(ctx->grammar);
+        llg_free_tokenizer(ctx->tokenizer);
     }
 
     delete ctx;
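Ownership stays symmetric across the lifecycle hooks: clone duplicates both handles (llg_clone_constraint, llg_clone_tokenizer), and free releases the tokenizer only inside the if (ctx->grammar) branch. That matches the constructor below, where tokenizer is non-null exactly when a grammar was supplied.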
@@ -2462,16 +2466,114 @@ static struct llama_sampler_i llama_sampler_llg_i = {
     /* .free = */ llama_sampler_llg_free,
 };
 
+
+static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
+                                            const uint8_t *bytes,
+                                            size_t bytes_len,
+                                            uint32_t *output_tokens,
+                                            size_t output_tokens_len)
+{
+    const struct llama_model *model = (const struct llama_model *)user_data;
+    int r = llama_tokenize(model, (const char *) bytes, bytes_len,
+        (int32_t*)output_tokens, output_tokens_len, false, true);
+    if (r < 0)
+        return -r;
+    return r;
+}
+
+static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * model) {
+    // TODO store the tokenizer in the model somehow
+    static const struct llama_model *model_cache;
+    static LlgTokenizer *tokenizer_cache;
+
+    if (model_cache == model) {
+        return llg_clone_tokenizer(tokenizer_cache);
+    }
+
+    auto tok_eos = llama_token_eot(model);
+    if (tok_eos == LLAMA_TOKEN_NULL)
+        tok_eos = llama_token_eos(model);
+
+    size_t vocab_size = llama_n_vocab(model);
+
+    auto token_lens = new uint32_t[vocab_size];
+    // we typically have ~7 bytes per token; let's go on the safe side here
+    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
+    auto token_bytes = new uint8_t[token_bytes_size];
+
+    size_t offset = 0;
+    for (size_t i = 0; i < vocab_size; i++) {
+        size_t max_token = 1024;
+        if (token_bytes_size - offset < max_token) {
+            GGML_ABORT("token_bytes buffer too small\n");
+        }
+
+        llama_token token = i;
+        auto dp = (char *) token_bytes + offset;
+        auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false);
+        if (size < 0) {
+            GGML_ABORT("llama_detokenize failed\n");
+        }
+        if (size == 0) {
+            size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true);
+            if (size < 0) {
+                GGML_ABORT("llama_detokenize failed\n");
+            }
+            if (size != 0) {
+                *dp = '\xff'; // special token prefix marker
+                size += 1;
+            }
+        }
+
+        token_lens[i] = size;
+        offset += size;
+    }
+
+
+    LlgTokenizerInit tinit = {
+        /* .vocab_size = */ (uint32_t)vocab_size,
+        /* .tok_eos = */ (uint32_t)tok_eos,
+        /* .token_lens = */ token_lens,
+        /* .token_bytes = */ token_bytes,
+        /* .tokenizer_json = */ nullptr,
+        /* .tokenize_assumes_string = */ false,
+        /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
+        /* .use_approximate_greedy_tokenize_fn = */ false,
+        /* .tokenize_user_data = */ model,
+    };
+
+    char error_buffer[1024];
+    LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
+
+    delete[] token_bytes;
+    delete[] token_lens;
+
+    if (tokenizer == nullptr) {
+        LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer);
+        return tokenizer;
+    }
+
+    if (tokenizer_cache) {
+        llg_free_tokenizer(tokenizer_cache);
+    }
+    model_cache = model;
+    tokenizer_cache = tokenizer;
+
+    return tokenizer;
+}
+
 struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
                                               const char * grammar_kind, const char * grammar_data) {
     auto * ctx = new llama_sampler_llg;
 
     if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
+        auto tokenizer = llama_sampler_llg_new_tokenizer(model);
         *ctx = {
             /* .model = */ model,
             /* .grammar_kind = */ grammar_kind,
             /* .grammar_data = */ grammar_data,
-            /* .grammar = */ llama_sampler_llg_new(grammar_kind, grammar_data),
+            /* .tokenizer = */ tokenizer,
+            /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
             /* .llg_res = */ {},
             /* .has_llg_res = */ false,
         };
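Two conventions in the new helper are easy to miss. First, llama_sampler_llg_tokenize_fn adapts llama_tokenize to the llguidance callback contract: llama_tokenize reports a too-small output buffer by returning the negative of the required token count, so the adapter flips the sign and always returns the needed count. Second, the vocabulary is handed over as a flat, length-prefixed table: token_lens[i] is the byte length of token i, the payloads are packed back-to-back in token_bytes, and a special token that detokenizes to nothing as plain text is re-rendered with special rendering enabled and tagged with a leading 0xff byte. A small reader for that layout (illustrative only; dump_token_table is not part of the patch):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Walks the (token_lens, token_bytes) table built in llama_sampler_llg_new_tokenizer.
    // A 0xff first byte is the marker the loop above writes for special tokens.
    static void dump_token_table(const uint32_t * token_lens,
                                 const uint8_t  * token_bytes,
                                 size_t vocab_size) {
        size_t offset = 0;
        for (size_t i = 0; i < vocab_size; i++) {
            const uint8_t * p  = token_bytes + offset;
            const uint32_t len = token_lens[i];
            const bool special = len > 0 && p[0] == 0xff;
            printf("token %zu: %u bytes%s\n", i, len, special ? " (special)" : "");
            offset += len; // payloads are concatenated in token order
        }
    }

The trailing static cache means repeated sampler creation against the same model pays the table-building cost only once; the TODO acknowledges the tokenizer really belongs on the model itself.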
@@ -2480,6 +2582,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
             /* .model = */ model,
             /* .grammar_kind = */ {},
             /* .grammar_data = */ {},
+            /* .tokenizer = */ nullptr,
             /* .grammar = */ nullptr,
             /* .llg_res = */ {},
             /* .has_llg_res = */ false,
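End to end, constructing a constrained sampler stays a single call. A hedged usage sketch (the "lark" kind string and grammar body are placeholders; the kinds actually accepted are whatever llg_new_constraint_any supports):

    // Hypothetical usage; grammar kind/data are placeholder values.
    struct llama_sampler * smpl = llama_sampler_init_llg(model, "lark", "start: /[0-9]+/");

    // An empty or null grammar_kind takes the second branch above and yields a
    // pass-through sampler with tokenizer == nullptr and grammar == nullptr.
    struct llama_sampler * plain = llama_sampler_init_llg(model, nullptr, nullptr);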
