@@ -11,25 +11,24 @@ struct llama_sampler_llg {
1111 std::string grammar_kind;
1212 std::string grammar_data;
1313 LlgTokenizer * tokenizer;
14- LlgConstraint * grammar;
15- LlgMaskResult llg_res;
16- bool has_llg_res;
14+ LlgMatcher * grammar;
1715};
1816
19- static LlgConstraint * llama_sampler_llg_new (LlgTokenizer * tokenizer, const char * grammar_kind,
20- const char * grammar_data) {
17+ static LlgMatcher * llama_sampler_llg_new (LlgTokenizer * tokenizer, const char * grammar_kind,
18+ const char * grammar_data) {
2119 LlgConstraintInit cinit;
2220 llg_constraint_init_set_defaults (&cinit, tokenizer);
2321 const char * log_level = getenv (" LLGUIDANCE_LOG_LEVEL" );
2422 if (log_level && *log_level) {
2523 cinit.log_stderr_level = atoi (log_level);
2624 }
27- auto c = llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
28- if (llg_get_error (c)) {
29- LOG_ERR (" llg error: %s\n " , llg_get_error (c));
30- llg_free_constraint (c);
25+ auto c = llg_new_matcher (&cinit, grammar_kind, grammar_data);
26+ if (llg_matcher_get_error (c)) {
27+ LOG_ERR (" llg error: %s\n " , llg_matcher_get_error (c));
28+ llg_free_matcher (c);
3129 return nullptr ;
3230 }
31+
3332 return c;
3433}
3534
@@ -40,54 +39,39 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
4039static void llama_sampler_llg_accept_impl (llama_sampler * smpl, llama_token token) {
4140 auto * ctx = (llama_sampler_llg *) smpl->ctx ;
4241 if (ctx->grammar ) {
43- LlgCommitResult res;
44- llg_commit_token (ctx->grammar , token, &res);
45- ctx->has_llg_res = false ;
42+ llg_matcher_consume_token (ctx->grammar , token);
4643 }
4744}
4845
4946static void llama_sampler_llg_apply (llama_sampler * smpl, llama_token_data_array * cur_p) {
5047 auto * ctx = (llama_sampler_llg *) smpl->ctx ;
5148 if (ctx->grammar ) {
52- if (!ctx->has_llg_res ) {
53- if (llg_compute_mask (ctx->grammar , &ctx->llg_res ) == 0 ) {
54- ctx->has_llg_res = true ;
49+ const uint32_t * mask = llg_matcher_get_mask (ctx->grammar );
50+ if (mask == nullptr ) {
51+ if (llg_matcher_compute_mask (ctx->grammar ) == 0 ) {
52+ mask = llg_matcher_get_mask (ctx->grammar );
5553 } else {
56- LOG_ERR (" llg error: %s\n " , llg_get_error (ctx->grammar ));
57- llg_free_constraint (ctx->grammar );
54+ LOG_ERR (" llg error: %s\n " , llg_matcher_get_error (ctx->grammar ));
55+ llg_free_matcher (ctx->grammar );
5856 ctx->grammar = nullptr ;
57+ return ;
5958 }
6059 }
61- if (ctx->has_llg_res ) {
62- if (ctx->llg_res .is_stop ) {
63- for (size_t i = 0 ; i < cur_p->size ; ++i) {
64- if (!llama_vocab_is_eog (ctx->vocab , cur_p->data [i].id )) {
65- cur_p->data [i].logit = -INFINITY;
66- }
67- }
68- } else {
69- const uint32_t * mask = ctx->llg_res .sample_mask ;
70- for (size_t i = 0 ; i < cur_p->size ; ++i) {
71- auto token = cur_p->data [i].id ;
72- if ((mask[token / 32 ] & (1 << (token % 32 ))) == 0 ) {
73- cur_p->data [i].logit = -INFINITY;
74- }
75- }
60+
61+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
62+ auto token = cur_p->data [i].id ;
63+ if ((mask[token / 32 ] & (1 << (token % 32 ))) == 0 ) {
64+ cur_p->data [i].logit = -INFINITY;
7665 }
7766 }
7867 }
7968}
8069
8170static void llama_sampler_llg_reset (llama_sampler * smpl) {
8271 auto * ctx = (llama_sampler_llg *) smpl->ctx ;
83- if (! ctx->grammar ) {
84- return ;
72+ if (ctx->grammar ) {
73+ llg_matcher_reset (ctx-> grammar ) ;
8574 }
86-
87- auto * grammar_new = llama_sampler_llg_new (ctx->tokenizer , ctx->grammar_kind .c_str (), ctx->grammar_data .c_str ());
88- llg_free_constraint (ctx->grammar );
89- ctx->grammar = grammar_new;
90- ctx->has_llg_res = false ;
9175}
9276
9377static llama_sampler * llama_sampler_llg_clone (const llama_sampler * smpl) {
@@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
10286 if (ctx->grammar ) {
10387 result_ctx->grammar_kind = ctx->grammar_kind ;
10488 result_ctx->grammar_data = ctx->grammar_data ;
105- result_ctx->grammar = llg_clone_constraint (ctx->grammar );
89+ result_ctx->grammar = llg_clone_matcher (ctx->grammar );
10690 result_ctx->tokenizer = llg_clone_tokenizer (ctx->tokenizer );
10791 }
10892 }
@@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
11498 const auto * ctx = (llama_sampler_llg *) smpl->ctx ;
11599
116100 if (ctx->grammar ) {
117- llg_free_constraint (ctx->grammar );
101+ llg_free_matcher (ctx->grammar );
118102 llg_free_tokenizer (ctx->tokenizer );
119103 }
120104
@@ -239,25 +223,24 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
239223 /* .grammar_data = */ grammar_data,
240224 /* .tokenizer = */ tokenizer,
241225 /* .grammar = */ llama_sampler_llg_new (tokenizer, grammar_kind, grammar_data),
242- /* .llg_res = */ {},
243- /* .has_llg_res = */ false ,
244226 };
227+ if (ctx->grammar ) {
228+ GGML_ASSERT (((size_t ) llama_vocab_n_tokens (vocab) + 31 ) / 32 * 4 ==
229+ llg_matcher_get_mask_byte_size (ctx->grammar ));
230+ }
245231 } else {
246232 *ctx = {
247233 /* .vocab = */ vocab,
248234 /* .grammar_kind = */ {},
249235 /* .grammar_data = */ {},
250236 /* .tokenizer = */ nullptr ,
251237 /* .grammar = */ nullptr ,
252- /* .llg_res = */ {},
253- /* .has_llg_res = */ false ,
254238 };
255239 }
256240
257241 return llama_sampler_init (
258242 /* .iface = */ &llama_sampler_llg_i,
259- /* .ctx = */ ctx
260- );
243+ /* .ctx = */ ctx);
261244}
262245
263246#else
0 commit comments