@@ -11,25 +11,24 @@ struct llama_sampler_llg {
11
11
std::string grammar_kind;
12
12
std::string grammar_data;
13
13
LlgTokenizer * tokenizer;
14
- LlgConstraint * grammar;
15
- LlgMaskResult llg_res;
16
- bool has_llg_res;
14
+ LlgMatcher * grammar;
17
15
};
18
16
19
- static LlgConstraint * llama_sampler_llg_new (LlgTokenizer * tokenizer, const char * grammar_kind,
20
- const char * grammar_data) {
17
+ static LlgMatcher * llama_sampler_llg_new (LlgTokenizer * tokenizer, const char * grammar_kind,
18
+ const char * grammar_data) {
21
19
LlgConstraintInit cinit;
22
20
llg_constraint_init_set_defaults (&cinit, tokenizer);
23
21
const char * log_level = getenv (" LLGUIDANCE_LOG_LEVEL" );
24
22
if (log_level && *log_level) {
25
23
cinit.log_stderr_level = atoi (log_level);
26
24
}
27
- auto c = llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
28
- if (llg_get_error (c)) {
29
- LOG_ERR (" llg error: %s\n " , llg_get_error (c));
30
- llg_free_constraint (c);
25
+ auto c = llg_new_matcher (&cinit, grammar_kind, grammar_data);
26
+ if (llg_matcher_get_error (c)) {
27
+ LOG_ERR (" llg error: %s\n " , llg_matcher_get_error (c));
28
+ llg_free_matcher (c);
31
29
return nullptr ;
32
30
}
31
+
33
32
return c;
34
33
}
35
34
@@ -40,54 +39,39 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
40
39
static void llama_sampler_llg_accept_impl (llama_sampler * smpl, llama_token token) {
41
40
auto * ctx = (llama_sampler_llg *) smpl->ctx ;
42
41
if (ctx->grammar ) {
43
- LlgCommitResult res;
44
- llg_commit_token (ctx->grammar , token, &res);
45
- ctx->has_llg_res = false ;
42
+ llg_matcher_consume_token (ctx->grammar , token);
46
43
}
47
44
}
48
45
49
46
static void llama_sampler_llg_apply (llama_sampler * smpl, llama_token_data_array * cur_p) {
50
47
auto * ctx = (llama_sampler_llg *) smpl->ctx ;
51
48
if (ctx->grammar ) {
52
- if (!ctx->has_llg_res ) {
53
- if (llg_compute_mask (ctx->grammar , &ctx->llg_res ) == 0 ) {
54
- ctx->has_llg_res = true ;
49
+ const uint32_t * mask = llg_matcher_get_mask (ctx->grammar );
50
+ if (mask == nullptr ) {
51
+ if (llg_matcher_compute_mask (ctx->grammar ) == 0 ) {
52
+ mask = llg_matcher_get_mask (ctx->grammar );
55
53
} else {
56
- LOG_ERR (" llg error: %s\n " , llg_get_error (ctx->grammar ));
57
- llg_free_constraint (ctx->grammar );
54
+ LOG_ERR (" llg error: %s\n " , llg_matcher_get_error (ctx->grammar ));
55
+ llg_free_matcher (ctx->grammar );
58
56
ctx->grammar = nullptr ;
57
+ return ;
59
58
}
60
59
}
61
- if (ctx->has_llg_res ) {
62
- if (ctx->llg_res .is_stop ) {
63
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
64
- if (!llama_vocab_is_eog (ctx->vocab , cur_p->data [i].id )) {
65
- cur_p->data [i].logit = -INFINITY;
66
- }
67
- }
68
- } else {
69
- const uint32_t * mask = ctx->llg_res .sample_mask ;
70
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
71
- auto token = cur_p->data [i].id ;
72
- if ((mask[token / 32 ] & (1 << (token % 32 ))) == 0 ) {
73
- cur_p->data [i].logit = -INFINITY;
74
- }
75
- }
60
+
61
+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
62
+ auto token = cur_p->data [i].id ;
63
+ if ((mask[token / 32 ] & (1 << (token % 32 ))) == 0 ) {
64
+ cur_p->data [i].logit = -INFINITY;
76
65
}
77
66
}
78
67
}
79
68
}
80
69
81
70
static void llama_sampler_llg_reset (llama_sampler * smpl) {
82
71
auto * ctx = (llama_sampler_llg *) smpl->ctx ;
83
- if (! ctx->grammar ) {
84
- return ;
72
+ if (ctx->grammar ) {
73
+ llg_matcher_reset (ctx-> grammar ) ;
85
74
}
86
-
87
- auto * grammar_new = llama_sampler_llg_new (ctx->tokenizer , ctx->grammar_kind .c_str (), ctx->grammar_data .c_str ());
88
- llg_free_constraint (ctx->grammar );
89
- ctx->grammar = grammar_new;
90
- ctx->has_llg_res = false ;
91
75
}
92
76
93
77
static llama_sampler * llama_sampler_llg_clone (const llama_sampler * smpl) {
@@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
102
86
if (ctx->grammar ) {
103
87
result_ctx->grammar_kind = ctx->grammar_kind ;
104
88
result_ctx->grammar_data = ctx->grammar_data ;
105
- result_ctx->grammar = llg_clone_constraint (ctx->grammar );
89
+ result_ctx->grammar = llg_clone_matcher (ctx->grammar );
106
90
result_ctx->tokenizer = llg_clone_tokenizer (ctx->tokenizer );
107
91
}
108
92
}
@@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
114
98
const auto * ctx = (llama_sampler_llg *) smpl->ctx ;
115
99
116
100
if (ctx->grammar ) {
117
- llg_free_constraint (ctx->grammar );
101
+ llg_free_matcher (ctx->grammar );
118
102
llg_free_tokenizer (ctx->tokenizer );
119
103
}
120
104
@@ -239,25 +223,24 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
239
223
/* .grammar_data = */ grammar_data,
240
224
/* .tokenizer = */ tokenizer,
241
225
/* .grammar = */ llama_sampler_llg_new (tokenizer, grammar_kind, grammar_data),
242
- /* .llg_res = */ {},
243
- /* .has_llg_res = */ false ,
244
226
};
227
+ if (ctx->grammar ) {
228
+ GGML_ASSERT (((size_t ) llama_vocab_n_tokens (vocab) + 31 ) / 32 * 4 ==
229
+ llg_matcher_get_mask_byte_size (ctx->grammar ));
230
+ }
245
231
} else {
246
232
*ctx = {
247
233
/* .vocab = */ vocab,
248
234
/* .grammar_kind = */ {},
249
235
/* .grammar_data = */ {},
250
236
/* .tokenizer = */ nullptr ,
251
237
/* .grammar = */ nullptr ,
252
- /* .llg_res = */ {},
253
- /* .has_llg_res = */ false ,
254
238
};
255
239
}
256
240
257
241
return llama_sampler_init (
258
242
/* .iface = */ &llama_sampler_llg_i,
259
- /* .ctx = */ ctx
260
- );
243
+ /* .ctx = */ ctx);
261
244
}
262
245
263
246
#else
0 commit comments