1- #ifdef LLAMA_LLGUIDANCE
1+ #ifdef LLAMA_USE_LLGUIDANCE
2+
3+ #include " common.h"
4+ #include " sampling.h"
5+ #include " log.h"
6+ #include " llama.h"
7+
28#include " llguidance.h"
39
410struct llama_sampler_llg {
511 const struct llama_model * model;
12+ const struct llama_vocab * vocab;
613 std::string grammar_kind;
714 std::string grammar_data;
815 LlgTokenizer *tokenizer;
@@ -17,7 +24,7 @@ static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
1724 llg_constraint_init_set_defaults (&cinit, tokenizer);
1825 auto c = llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
1926 if (llg_get_error (c)) {
20- LLAMA_LOG_ERROR (" llg error: %s\n " , llg_get_error (c));
27+ LOG_ERR (" llg error: %s\n " , llg_get_error (c));
2128 llg_free_constraint (c);
2229 return nullptr ;
2330 }
@@ -44,15 +51,15 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
4451 if (llg_compute_mask (ctx->grammar , &ctx->llg_res ) == 0 ) {
4552 ctx->has_llg_res = true ;
4653 } else {
47- LLAMA_LOG_ERROR (" llg error: %s\n " , llg_get_error (ctx->grammar ));
54+ LOG_ERR (" llg error: %s\n " , llg_get_error (ctx->grammar ));
4855 llg_free_constraint (ctx->grammar );
4956 ctx->grammar = nullptr ;
5057 }
5158 }
5259 if (ctx->has_llg_res ) {
5360 if (ctx->llg_res .is_stop ) {
5461 for (size_t i = 0 ; i < cur_p->size ; ++i) {
55- if (!llama_token_is_eog (ctx->model , cur_p->data [i].id )) {
62+ if (!llama_vocab_is_eog (ctx->vocab , cur_p->data [i].id )) {
5663 cur_p->data [i].logit = -INFINITY;
5764 }
5865 }
@@ -128,8 +135,8 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
128135 uint32_t *output_tokens,
129136 size_t output_tokens_len)
130137{
131- const struct llama_model *model = (const struct llama_model *)user_data;
132- int r = llama_tokenize (model , (const char *) bytes, bytes_len,
138+ const struct llama_vocab *vocab = (const struct llama_vocab *)user_data;
139+ int r = llama_tokenize (vocab , (const char *) bytes, bytes_len,
133140 (int32_t *)output_tokens, output_tokens_len, false , true );
134141 if (r < 0 )
135142 return -r;
@@ -145,11 +152,13 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
145152 return llg_clone_tokenizer (tokenizer_cache);
146153 }
147154
148- auto tok_eos = llama_token_eot (model);
155+ const struct llama_vocab *vocab = llama_model_get_vocab (model);
156+
157+ auto tok_eos = llama_vocab_eot (vocab);
149158 if (tok_eos == LLAMA_TOKEN_NULL)
150- tok_eos = llama_token_eos (model );
159+ tok_eos = llama_vocab_eos (vocab );
151160
152- size_t vocab_size = llama_n_vocab (model );
161+ size_t vocab_size = llama_vocab_n_tokens (vocab );
153162
154163 auto token_lens = new uint32_t [vocab_size];
155164 // we typically have ~7 bytes per token; let's go on the safe side here
@@ -165,12 +174,12 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
165174
166175 llama_token token = i;
167176 auto dp = (char *) token_bytes + offset;
168- auto size = llama_detokenize (model , &token, 1 , dp, max_token, false , false );
177+ auto size = llama_detokenize (vocab , &token, 1 , dp, max_token, false , false );
169178 if (size < 0 ) {
170179 GGML_ABORT (" llama_detokenize failed\n " );
171180 }
172181 if (size == 0 ) {
173- size = llama_detokenize (model , &token, 1 , dp + 1 , max_token - 1 , false , true );
182+ size = llama_detokenize (vocab , &token, 1 , dp + 1 , max_token - 1 , false , true );
174183 if (size < 0 ) {
175184 GGML_ABORT (" llama_detokenize failed\n " );
176185 }
@@ -194,7 +203,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
194203 /* .tokenize_assumes_string = */ false ,
195204 /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
196205 /* .use_approximate_greedy_tokenize_fn = */ false ,
197- /* .tokenize_user_data = */ model ,
206+ /* .tokenize_user_data = */ vocab ,
198207 };
199208
200209 char error_buffer[1024 ];
@@ -204,7 +213,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
204213 delete[] token_lens;
205214
206215 if (tokenizer == nullptr ) {
207- LLAMA_LOG_ERROR (" llg tokenizer error: %s\n " , error_buffer);
216+ LOG_ERR (" llg tokenizer error: %s\n " , error_buffer);
208217 return tokenizer;
209218 }
210219
@@ -221,10 +230,13 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
221230 const char * grammar_kind, const char * grammar_data) {
222231 auto * ctx = new llama_sampler_llg;
223232
233+ const llama_vocab * vocab = llama_model_get_vocab (model);
234+
224235 if (grammar_kind != nullptr && grammar_kind[0 ] != ' \0 ' ) {
225236 auto tokenizer = llama_sampler_llg_new_tokenizer (model);
226237 *ctx = {
227238 /* .model = */ model,
239+ /* .vocab = */ vocab,
228240 /* .grammar_kind = */ grammar_kind,
229241 /* .grammar_data = */ grammar_data,
230242 /* .tokenizer = */ tokenizer,
@@ -235,6 +247,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
235247 } else {
236248 *ctx = {
237249 /* .model = */ model,
250+ /* .vocab = */ vocab,
238251 /* .grammar_kind = */ {},
239252 /* .grammar_data = */ {},
240253 /* .tokenizer = */ nullptr ,
0 commit comments