Skip to content

Commit 2447ad8

Browse files
authored
upgrade to llguidance 0.7.10 (#12576)
1 parent 02082f1 commit 2447ad8

File tree

3 files changed

+94
-49
lines changed

3 files changed

+94
-49
lines changed

common/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE)
114114

115115
ExternalProject_Add(llguidance_ext
116116
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
117-
# v0.6.12:
118-
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
117+
# v0.7.10:
118+
GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8
119119
PREFIX ${CMAKE_BINARY_DIR}/llguidance
120120
SOURCE_DIR ${LLGUIDANCE_SRC}
121121
BUILD_IN_SOURCE TRUE

common/llguidance.cpp

Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,24 @@ struct llama_sampler_llg {
1111
std::string grammar_kind;
1212
std::string grammar_data;
1313
LlgTokenizer * tokenizer;
14-
LlgConstraint * grammar;
15-
LlgMaskResult llg_res;
16-
bool has_llg_res;
14+
LlgMatcher * grammar;
1715
};
1816

19-
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
20-
const char * grammar_data) {
17+
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
18+
const char * grammar_data) {
2119
LlgConstraintInit cinit;
2220
llg_constraint_init_set_defaults(&cinit, tokenizer);
2321
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
2422
if (log_level && *log_level) {
2523
cinit.log_stderr_level = atoi(log_level);
2624
}
27-
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
28-
if (llg_get_error(c)) {
29-
LOG_ERR("llg error: %s\n", llg_get_error(c));
30-
llg_free_constraint(c);
25+
auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
26+
if (llg_matcher_get_error(c)) {
27+
LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
28+
llg_free_matcher(c);
3129
return nullptr;
3230
}
31+
3332
return c;
3433
}
3534

@@ -40,54 +39,39 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
4039
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
4140
auto * ctx = (llama_sampler_llg *) smpl->ctx;
4241
if (ctx->grammar) {
43-
LlgCommitResult res;
44-
llg_commit_token(ctx->grammar, token, &res);
45-
ctx->has_llg_res = false;
42+
llg_matcher_consume_token(ctx->grammar, token);
4643
}
4744
}
4845

4946
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
5047
auto * ctx = (llama_sampler_llg *) smpl->ctx;
5148
if (ctx->grammar) {
52-
if (!ctx->has_llg_res) {
53-
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
54-
ctx->has_llg_res = true;
49+
const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
50+
if (mask == nullptr) {
51+
if (llg_matcher_compute_mask(ctx->grammar) == 0) {
52+
mask = llg_matcher_get_mask(ctx->grammar);
5553
} else {
56-
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
57-
llg_free_constraint(ctx->grammar);
54+
LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
55+
llg_free_matcher(ctx->grammar);
5856
ctx->grammar = nullptr;
57+
return;
5958
}
6059
}
61-
if (ctx->has_llg_res) {
62-
if (ctx->llg_res.is_stop) {
63-
for (size_t i = 0; i < cur_p->size; ++i) {
64-
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
65-
cur_p->data[i].logit = -INFINITY;
66-
}
67-
}
68-
} else {
69-
const uint32_t * mask = ctx->llg_res.sample_mask;
70-
for (size_t i = 0; i < cur_p->size; ++i) {
71-
auto token = cur_p->data[i].id;
72-
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
73-
cur_p->data[i].logit = -INFINITY;
74-
}
75-
}
60+
61+
for (size_t i = 0; i < cur_p->size; ++i) {
62+
auto token = cur_p->data[i].id;
63+
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
64+
cur_p->data[i].logit = -INFINITY;
7665
}
7766
}
7867
}
7968
}
8069

8170
static void llama_sampler_llg_reset(llama_sampler * smpl) {
8271
auto * ctx = (llama_sampler_llg *) smpl->ctx;
83-
if (!ctx->grammar) {
84-
return;
72+
if (ctx->grammar) {
73+
llg_matcher_reset(ctx->grammar);
8574
}
86-
87-
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
88-
llg_free_constraint(ctx->grammar);
89-
ctx->grammar = grammar_new;
90-
ctx->has_llg_res = false;
9175
}
9276

9377
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
@@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
10286
if (ctx->grammar) {
10387
result_ctx->grammar_kind = ctx->grammar_kind;
10488
result_ctx->grammar_data = ctx->grammar_data;
105-
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
89+
result_ctx->grammar = llg_clone_matcher(ctx->grammar);
10690
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
10791
}
10892
}
@@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
11498
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
11599

116100
if (ctx->grammar) {
117-
llg_free_constraint(ctx->grammar);
101+
llg_free_matcher(ctx->grammar);
118102
llg_free_tokenizer(ctx->tokenizer);
119103
}
120104

@@ -239,25 +223,24 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
239223
/* .grammar_data = */ grammar_data,
240224
/* .tokenizer = */ tokenizer,
241225
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
242-
/* .llg_res = */ {},
243-
/* .has_llg_res = */ false,
244226
};
227+
if (ctx->grammar) {
228+
GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
229+
llg_matcher_get_mask_byte_size(ctx->grammar));
230+
}
245231
} else {
246232
*ctx = {
247233
/* .vocab = */ vocab,
248234
/* .grammar_kind = */ {},
249235
/* .grammar_data = */ {},
250236
/* .tokenizer = */ nullptr,
251237
/* .grammar = */ nullptr,
252-
/* .llg_res = */ {},
253-
/* .has_llg_res = */ false,
254238
};
255239
}
256240

257241
return llama_sampler_init(
258242
/* .iface = */ &llama_sampler_llg_i,
259-
/* .ctx = */ ctx
260-
);
243+
/* .ctx = */ ctx);
261244
}
262245

263246
#else

tests/test-grammar-llguidance.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,65 @@ static void test_json_schema() {
10861086
});
10871087
}
10881088

1089+
static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
1090+
auto n_vocab = tok_arr.size;
1091+
1092+
tok_arr.selected = -1;
1093+
tok_arr.sorted = false;
1094+
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1095+
tok_arr.data[token_id].id = token_id;
1096+
tok_arr.data[token_id].logit = 0.0f;
1097+
}
1098+
1099+
tok_arr.data[selected].logit = 100.0f;
1100+
}
1101+
1102+
static void test_sampler_chain(void) {
1103+
auto sparams = llama_sampler_chain_default_params();
1104+
sparams.no_perf = false;
1105+
llama_sampler * sampler = llama_sampler_chain_init(sparams);
1106+
1107+
const auto grammar_data = R"(%llguidance {}
1108+
start: /[A-Z ]*/)";
1109+
1110+
llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
1111+
llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
1112+
1113+
auto input = "ALL YOUR BASE ARE BELONG TO US";
1114+
auto tokens = common_tokenize(vocab, input, false, false);
1115+
1116+
auto n_vocab = llama_vocab_n_tokens(vocab);
1117+
1118+
std::vector<llama_token_data> cur;
1119+
cur.reserve(n_vocab);
1120+
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
1121+
cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
1122+
}
1123+
auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
1124+
1125+
for (const auto token : tokens) {
1126+
one_hot(tok_arr, token);
1127+
1128+
fprintf(stderr, "applying token: %d\n", token);
1129+
llama_sampler_apply(sampler, &tok_arr);
1130+
1131+
auto idx = tok_arr.selected;
1132+
fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
1133+
assert(cur[tok_arr.selected].id == token);
1134+
llama_sampler_accept(sampler, token);
1135+
}
1136+
1137+
auto tok_eos = llama_vocab_eot(vocab);
1138+
if (tok_eos == LLAMA_TOKEN_NULL) {
1139+
tok_eos = llama_vocab_eos(vocab);
1140+
}
1141+
1142+
one_hot(tok_arr, tok_eos);
1143+
1144+
llama_sampler_apply(sampler, &tok_arr);
1145+
assert(cur[tok_arr.selected].id == tok_eos);
1146+
}
1147+
10891148
int main(int argc, const char ** argv) {
10901149
fprintf(stdout, "Running llguidance integration tests...\n");
10911150

@@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
11351194
test_special_chars();
11361195
test_quantifiers();
11371196
test_json_schema();
1197+
1198+
test_sampler_chain();
1199+
11381200
fprintf(stdout, "All tests passed.\n");
11391201
return 0;
11401202
}

0 commit comments

Comments
 (0)