@@ -2350,14 +2350,16 @@ struct llama_sampler_llg {
23502350 const struct llama_model * model;
23512351 std::string grammar_kind;
23522352 std::string grammar_data;
2353+ LlgTokenizer *tokenizer;
23532354 LlgConstraint *grammar;
23542355 LlgMaskResult llg_res;
23552356 bool has_llg_res;
23562357};
23572358
2358- static LlgConstraint *llama_sampler_llg_new (const char * grammar_kind, const char * grammar_data) {
2359+ static LlgConstraint *llama_sampler_llg_new (LlgTokenizer *tokenizer,
2360+ const char * grammar_kind, const char * grammar_data) {
23592361 LlgConstraintInit cinit;
2360- llg_constraint_init_set_defaults (&cinit, nullptr );
2362+ llg_constraint_init_set_defaults (&cinit, tokenizer );
23612363 auto c = llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
23622364 if (llg_get_error (c)) {
23632365 LLAMA_LOG_ERROR (" llg error: %s\n " , llg_get_error (c));
@@ -2418,7 +2420,7 @@ static void llama_sampler_llg_reset(struct llama_sampler * smpl) {
24182420 return ;
24192421 }
24202422
2421- auto * grammar_new = llama_sampler_llg_new (ctx->grammar_kind .c_str (), ctx->grammar_data .c_str ());
2423+ auto * grammar_new = llama_sampler_llg_new (ctx->tokenizer , ctx-> grammar_kind .c_str (), ctx->grammar_data .c_str ());
24222424 llg_free_constraint (ctx->grammar );
24232425 ctx->grammar = grammar_new;
24242426 ctx->has_llg_res = false ;
@@ -2437,6 +2439,7 @@ static struct llama_sampler * llama_sampler_llg_clone(const struct llama_sampler
24372439 result_ctx->grammar_kind = ctx->grammar_kind ;
24382440 result_ctx->grammar_data = ctx->grammar_data ;
24392441 result_ctx->grammar = llg_clone_constraint (ctx->grammar );
2442+ result_ctx->tokenizer = llg_clone_tokenizer (ctx->tokenizer );
24402443 }
24412444 }
24422445
@@ -2448,6 +2451,7 @@ static void llama_sampler_llg_free(struct llama_sampler * smpl) {
24482451
24492452 if (ctx->grammar ) {
24502453 llg_free_constraint (ctx->grammar );
2454+ llg_free_tokenizer (ctx->tokenizer );
24512455 }
24522456
24532457 delete ctx;
@@ -2462,16 +2466,114 @@ static struct llama_sampler_i llama_sampler_llg_i = {
24622466 /* .free = */ llama_sampler_llg_free,
24632467};
24642468
2469+
2470+ static size_t llama_sampler_llg_tokenize_fn (const void *user_data,
2471+ const uint8_t *bytes,
2472+ size_t bytes_len,
2473+ uint32_t *output_tokens,
2474+ size_t output_tokens_len)
2475+ {
2476+ const struct llama_model *model = (const struct llama_model *)user_data;
2477+ int r = llama_tokenize (model, (const char *) bytes, bytes_len,
2478+ (int32_t *)output_tokens, output_tokens_len, false , true );
2479+ if (r < 0 )
2480+ return -r;
2481+ return r;
2482+ }
2483+
2484+ static LlgTokenizer *llama_sampler_llg_new_tokenizer (const struct llama_model * model) {
2485+ // TODO store the tokenizer in the model somehow
2486+ static const struct llama_model *model_cache;
2487+ static LlgTokenizer *tokenizer_cache;
2488+
2489+ if (model_cache == model) {
2490+ return llg_clone_tokenizer (tokenizer_cache);
2491+ }
2492+
2493+ auto tok_eos = llama_token_eot (model);
2494+ if (tok_eos == LLAMA_TOKEN_NULL)
2495+ tok_eos = llama_token_eos (model);
2496+
2497+ size_t vocab_size = llama_n_vocab (model);
2498+
2499+ auto token_lens = new uint32_t [vocab_size];
2500+ // we typically have ~7 bytes per token; let's go on the safe side here
2501+ auto token_bytes_size = vocab_size * 16 + 1024 * 1024 ;
2502+ auto token_bytes = new uint8_t [token_bytes_size];
2503+
2504+ size_t offset = 0 ;
2505+ for (size_t i = 0 ; i < vocab_size; i++) {
2506+ size_t max_token = 1024 ;
2507+ if (token_bytes_size - offset < max_token) {
2508+ GGML_ABORT (" token_bytes buffer too small\n " );
2509+ }
2510+
2511+ llama_token token = i;
2512+ auto dp = (char *) token_bytes + offset;
2513+ auto size = llama_detokenize (model, &token, 1 , dp, max_token, false , false );
2514+ if (size < 0 ) {
2515+ GGML_ABORT (" llama_detokenize failed\n " );
2516+ }
2517+ if (size == 0 ) {
2518+ size = llama_detokenize (model, &token, 1 , dp + 1 , max_token - 1 , false , true );
2519+ if (size < 0 ) {
2520+ GGML_ABORT (" llama_detokenize failed\n " );
2521+ }
2522+ if (size != 0 ) {
2523+ *dp = ' \xff ' ; // special token prefix marker
2524+ size += 1 ;
2525+ }
2526+ }
2527+
2528+ token_lens[i] = size;
2529+ offset += size;
2530+ }
2531+
2532+
2533+ LlgTokenizerInit tinit = {
2534+ /* .vocab_size = */ (uint32_t )vocab_size,
2535+ /* .tok_eos = */ (uint32_t )tok_eos,
2536+ /* .token_lens = */ token_lens,
2537+ /* .token_bytes = */ token_bytes,
2538+ /* .tokenizer_json = */ nullptr ,
2539+ /* .tokenize_assumes_string = */ false ,
2540+ /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
2541+ /* .use_approximate_greedy_tokenize_fn = */ false ,
2542+ /* .tokenize_user_data = */ model,
2543+ };
2544+
2545+ char error_buffer[1024 ];
2546+ LlgTokenizer *tokenizer = llg_new_tokenizer (&tinit, error_buffer, sizeof (error_buffer));
2547+
2548+ delete[] token_bytes;
2549+ delete[] token_lens;
2550+
2551+ if (tokenizer == nullptr ) {
2552+ LLAMA_LOG_ERROR (" llg tokenizer error: %s\n " , error_buffer);
2553+ return tokenizer;
2554+ }
2555+
2556+ if (tokenizer_cache) {
2557+ llg_free_tokenizer (tokenizer_cache);
2558+ }
2559+ model_cache = model;
2560+ tokenizer_cache = tokenizer;
2561+
2562+ return tokenizer;
2563+ }
2564+
24652565struct llama_sampler * llama_sampler_init_llg (const struct llama_model * model,
24662566 const char * grammar_kind, const char * grammar_data) {
24672567 auto * ctx = new llama_sampler_llg;
24682568
24692569 if (grammar_kind != nullptr && grammar_kind[0 ] != ' \0 ' ) {
2570+ auto tokenizer = llama_sampler_llg_new_tokenizer (model);
24702571 *ctx = {
24712572 /* .model = */ model,
24722573 /* .grammar_kind = */ grammar_kind,
24732574 /* .grammar_data = */ grammar_data,
2474- /* .grammar = */ llama_sampler_llg_new (grammar_kind, grammar_data),
2575+ /* .tokenizer = */ tokenizer,
2576+ /* .grammar = */ llama_sampler_llg_new (tokenizer, grammar_kind, grammar_data),
24752577 /* .llg_res = */ {},
24762578 /* .has_llg_res = */ false ,
24772579 };
@@ -2480,6 +2582,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
24802582 /* .model = */ model,
24812583 /* .grammar_kind = */ {},
24822584 /* .grammar_data = */ {},
2585+ /* .tokenizer = */ nullptr ,
24832586 /* .grammar = */ nullptr ,
24842587 /* .llg_res = */ {},
24852588 /* .has_llg_res = */ false ,
0 commit comments