@@ -11,25 +11,24 @@ struct llama_sampler_llg {
1111    std::string         grammar_kind;
1212    std::string         grammar_data;
1313    LlgTokenizer *      tokenizer;
14-     LlgConstraint *     grammar;
15-     LlgMaskResult       llg_res;
16-     bool                 has_llg_res;
14+     LlgMatcher *        grammar;
1715};
1816
19- static  LlgConstraint  * llama_sampler_llg_new (LlgTokenizer * tokenizer, const  char  * grammar_kind,
20-                                               const  char  * grammar_data) {
17+ static  LlgMatcher  * llama_sampler_llg_new (LlgTokenizer * tokenizer, const  char  * grammar_kind,
18+                                           const  char  * grammar_data) {
2119    LlgConstraintInit cinit;
2220    llg_constraint_init_set_defaults (&cinit, tokenizer);
2321    const  char  * log_level = getenv (" LLGUIDANCE_LOG_LEVEL"  );
2422    if  (log_level && *log_level) {
2523        cinit.log_stderr_level  = atoi (log_level);
2624    }
27-     auto  c = llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
28-     if  (llg_get_error (c)) {
29-         LOG_ERR (" llg error: %s\n "  , llg_get_error (c));
30-         llg_free_constraint (c);
25+     auto  c = llg_new_matcher (&cinit, grammar_kind, grammar_data);
26+     if  (llg_matcher_get_error (c)) {
27+         LOG_ERR (" llg error: %s\n "  , llg_matcher_get_error (c));
28+         llg_free_matcher (c);
3129        return  nullptr ;
3230    }
31+ 
3332    return  c;
3433}
3534
@@ -40,54 +39,39 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
4039static  void  llama_sampler_llg_accept_impl (llama_sampler * smpl, llama_token token) {
4140    auto  * ctx = (llama_sampler_llg *) smpl->ctx ;
4241    if  (ctx->grammar ) {
43-         LlgCommitResult res;
44-         llg_commit_token (ctx->grammar , token, &res);
45-         ctx->has_llg_res  = false ;
42+         llg_matcher_consume_token (ctx->grammar , token);
4643    }
4744}
4845
4946static  void  llama_sampler_llg_apply (llama_sampler * smpl, llama_token_data_array * cur_p) {
5047    auto  * ctx = (llama_sampler_llg *) smpl->ctx ;
5148    if  (ctx->grammar ) {
52-         if  (!ctx->has_llg_res ) {
53-             if  (llg_compute_mask (ctx->grammar , &ctx->llg_res ) == 0 ) {
54-                 ctx->has_llg_res  = true ;
49+         const  uint32_t  * mask = llg_matcher_get_mask (ctx->grammar );
50+         if  (mask == nullptr ) {
51+             if  (llg_matcher_compute_mask (ctx->grammar ) == 0 ) {
52+                 mask = llg_matcher_get_mask (ctx->grammar );
5553            } else  {
56-                 LOG_ERR (" llg error: %s\n "  , llg_get_error (ctx->grammar ));
57-                 llg_free_constraint (ctx->grammar );
54+                 LOG_ERR (" llg error: %s\n "  , llg_matcher_get_error (ctx->grammar ));
55+                 llg_free_matcher (ctx->grammar );
5856                ctx->grammar  = nullptr ;
57+                 return ;
5958            }
6059        }
61-         if  (ctx->has_llg_res ) {
62-             if  (ctx->llg_res .is_stop ) {
63-                 for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
64-                     if  (!llama_vocab_is_eog (ctx->vocab , cur_p->data [i].id )) {
65-                         cur_p->data [i].logit  = -INFINITY;
66-                     }
67-                 }
68-             } else  {
69-                 const  uint32_t  * mask = ctx->llg_res .sample_mask ;
70-                 for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
71-                     auto  token = cur_p->data [i].id ;
72-                     if  ((mask[token / 32 ] & (1  << (token % 32 ))) == 0 ) {
73-                         cur_p->data [i].logit  = -INFINITY;
74-                     }
75-                 }
60+ 
61+         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
62+             auto  token = cur_p->data [i].id ;
63+             if  ((mask[token / 32 ] & (1  << (token % 32 ))) == 0 ) {
64+                 cur_p->data [i].logit  = -INFINITY;
7665            }
7766        }
7867    }
7968}
8069
8170static  void  llama_sampler_llg_reset (llama_sampler * smpl) {
8271    auto  * ctx = (llama_sampler_llg *) smpl->ctx ;
83-     if  (! ctx->grammar ) {
84-         return ;
72+     if  (ctx->grammar ) {
73+         llg_matcher_reset (ctx-> grammar ) ;
8574    }
86- 
87-     auto  * grammar_new = llama_sampler_llg_new (ctx->tokenizer , ctx->grammar_kind .c_str (), ctx->grammar_data .c_str ());
88-     llg_free_constraint (ctx->grammar );
89-     ctx->grammar      = grammar_new;
90-     ctx->has_llg_res  = false ;
9175}
9276
9377static  llama_sampler * llama_sampler_llg_clone (const  llama_sampler * smpl) {
@@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
10286        if  (ctx->grammar ) {
10387            result_ctx->grammar_kind  = ctx->grammar_kind ;
10488            result_ctx->grammar_data  = ctx->grammar_data ;
105-             result_ctx->grammar       = llg_clone_constraint (ctx->grammar );
89+             result_ctx->grammar       = llg_clone_matcher (ctx->grammar );
10690            result_ctx->tokenizer     = llg_clone_tokenizer (ctx->tokenizer );
10791        }
10892    }
@@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
11498    const  auto  * ctx = (llama_sampler_llg *) smpl->ctx ;
11599
116100    if  (ctx->grammar ) {
117-         llg_free_constraint (ctx->grammar );
101+         llg_free_matcher (ctx->grammar );
118102        llg_free_tokenizer (ctx->tokenizer );
119103    }
120104
@@ -239,25 +223,24 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
239223            /*  .grammar_data = */   grammar_data,
240224            /*  .tokenizer    = */   tokenizer,
241225            /*  .grammar      = */   llama_sampler_llg_new (tokenizer, grammar_kind, grammar_data),
242-             /*  .llg_res      = */   {},
243-             /*  .has_llg_res  = */   false ,
244226        };
227+         if  (ctx->grammar ) {
228+             GGML_ASSERT (((size_t ) llama_vocab_n_tokens (vocab) + 31 ) / 32  * 4  ==
229+                         llg_matcher_get_mask_byte_size (ctx->grammar ));
230+         }
245231    } else  {
246232        *ctx = {
247233            /*  .vocab        = */   vocab,
248234            /*  .grammar_kind = */   {},
249235            /*  .grammar_data = */   {},
250236            /*  .tokenizer    = */   nullptr ,
251237            /*  .grammar      = */   nullptr ,
252-             /*  .llg_res      = */   {},
253-             /*  .has_llg_res  = */   false ,
254238        };
255239    }
256240
257241    return  llama_sampler_init (
258242        /*  .iface = */   &llama_sampler_llg_i,
259-         /*  .ctx   = */   ctx
260-     );
243+         /*  .ctx   = */   ctx);
261244}
262245
263246#else 
0 commit comments