Commit 76290d9

initial porting of previous LLG patch
1 parent 4a75d19 commit 76290d9

5 files changed: 277 additions, 1 deletion

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -143,3 +143,4 @@ poetry.toml
 # Local scripts
 /run-vim.sh
 /run-chat.sh
+include/llguidance.h

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
7979

8080
# 3rd party libs
8181
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
82+
option(LLAMA_LLGUIDANCE "llama: build LLGuidance library for structured output" OFF)
8283

8384
# Required for relocatable CMake package
8485
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
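Usage note: with this option in place, a build that enables the LLGuidance code path might be configured as below (a minimal sketch; fetching and linking the llguidance library itself is not part of this hunk):

    # configure with the new option enabled, then build
    cmake -B build -DLLAMA_LLGUIDANCE=ON
    cmake --build build --config Release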

common/json-schema-to-grammar.cpp

Lines changed: 4 additions & 0 deletions
@@ -991,11 +991,15 @@ class SchemaConverter {
 };
 
 std::string json_schema_to_grammar(const json & schema) {
+#ifdef LLAMA_LLGUIDANCE
+    return "llg:json:" + schema.dump();
+#else
     return build_grammar([&](const llama_grammar_builder & callbacks) {
         auto copy = schema;
         callbacks.resolve_refs(copy);
         callbacks.add_schema("", copy);
     });
+#endif
 }
 
 std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
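To illustrate the new return path: when LLAMA_LLGUIDANCE is defined, the schema is no longer converted to a GBNF grammar; it is serialized behind an "llg:json:" prefix and handed downstream unchanged. A standalone sketch (assuming nlohmann::json, which this file already uses for its `json` type; the example schema is hypothetical):

    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <string>

    int main() {
        nlohmann::json schema = {{"type", "object"}};
        // Mirrors the #ifdef branch above: prefix + serialized schema, no GBNF conversion.
        std::string grammar = "llg:json:" + schema.dump();
        std::cout << grammar << "\n";   // prints: llg:json:{"type":"object"}
        return 0;
    }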

common/llguidance.cpp

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
#ifdef LLAMA_LLGUIDANCE
#include "llguidance.h"

struct llama_sampler_llg {
    const struct llama_model * model;
    std::string grammar_kind;
    std::string grammar_data;
    LlgTokenizer *tokenizer;
    LlgConstraint *grammar;
    LlgMaskResult llg_res;
    bool has_llg_res;
};

static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
                                            const char * grammar_kind, const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
    if (llg_get_error(c)) {
        LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c));
        llg_free_constraint(c);
        return nullptr;
    }
    return c;
}

static const char * llama_sampler_llg_name(const struct llama_sampler * /*smpl*/) {
    return "llguidance";
}

static void llama_sampler_llg_accept_impl(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        LlgCommitResult res;
        llg_commit_token(ctx->grammar, token, &res);
        ctx->has_llg_res = false;
    }
}

static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        if (!ctx->has_llg_res) {
            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
                ctx->has_llg_res = true;
            } else {
                LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar));
                llg_free_constraint(ctx->grammar);
                ctx->grammar = nullptr;
            }
        }
        if (ctx->has_llg_res) {
            if (ctx->llg_res.is_stop) {
                for (size_t i = 0; i < cur_p->size; ++i) {
                    if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            } else {
                const uint32_t *mask = ctx->llg_res.sample_mask;
                for (size_t i = 0; i < cur_p->size; ++i) {
                    auto token = cur_p->data[i].id;
                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            }
        }
    }
}

static void llama_sampler_llg_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
    llg_free_constraint(ctx->grammar);
    ctx->grammar = grammar_new;
    ctx->has_llg_res = false;
}

static struct llama_sampler * llama_sampler_llg_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar = llg_clone_constraint(ctx->grammar);
            result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}

static void llama_sampler_llg_free(struct llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static struct llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
};

static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
                                            const uint8_t *bytes,
                                            size_t bytes_len,
                                            uint32_t *output_tokens,
                                            size_t output_tokens_len)
{
    const struct llama_model *model = (const struct llama_model *)user_data;
    int r = llama_tokenize(model, (const char *) bytes, bytes_len,
                           (int32_t*)output_tokens, output_tokens_len, false, true);
    if (r < 0)
        return -r;
    return r;
}

static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * model) {
    // TODO store the tokenizer in the model somehow
    static const struct llama_model *model_cache;
    static LlgTokenizer *tokenizer_cache;

    if (model_cache == model) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_token_eot(model);
    if (tok_eos == LLAMA_TOKEN_NULL)
        tok_eos = llama_token_eos(model);

    size_t vocab_size = llama_n_vocab(model);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto dp = (char *) token_bytes + offset;
        auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff'; // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t)vocab_size,
        /* .tok_eos                            = */ (uint32_t)tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ false,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ model,
    };

    char error_buffer[1024];
    LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    model_cache = model;
    tokenizer_cache = tokenizer;

    return tokenizer;
}

struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
                                              const char * grammar_kind, const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(model);
        *ctx = {
            /* .model        = */ model,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        *ctx = {
            /* .model        = */ model,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return new llama_sampler {
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx,
    };
}

#endif
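Usage sketch (hypothetical caller code, not part of this patch): llama_sampler_init_llg takes the model plus a grammar kind and payload, matching the llg:<kind>:<data> strings parsed in common/sampling.cpp below; the resulting sampler is then driven through the usual llama_sampler interface.

    // Hypothetical usage, assuming a loaded `llama_model * model`.
    // "json" plus the schema text correspond to the "llg:json:<schema>" string
    // produced by json_schema_to_grammar() above; other grammar kinds depend on
    // what the llguidance library accepts and are not shown in this patch.
    struct llama_sampler * smpl = llama_sampler_init_llg(model, "json", "{\"type\": \"object\"}");

    // During generation: llama_sampler_apply() masks logits of tokens the
    // constraint forbids, and llama_sampler_accept() commits the chosen token
    // so the next mask advances.
    llama_sampler_free(smpl);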

common/sampling.cpp

Lines changed: 18 additions & 1 deletion
@@ -151,9 +151,26 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
+    struct llama_sampler * grmr;
+    if (params.grammar.compare(0, 4, "llg:") == 0) {
+#ifdef LLAMA_LLGUIDANCE
+        auto gp = params.grammar.find(':', 4);
+        if (gp == std::string::npos) {
+            GGML_ABORT("invalid serialized grammar");
+        }
+        auto grm_type = params.grammar.substr(4, gp - 4);
+        auto grm_data = params.grammar.c_str() + gp + 1;
+        grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data);
+#else
+        GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled");
+#endif
+    } else {
+        grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr   = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
+        /* .grmr   = */ grmr,
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
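For illustration, the grammar strings this dispatch accepts have the form llg:<kind>:<data>; anything not starting with "llg:" falls back to the existing GBNF path with the "root" rule. A minimal standalone sketch of the same parsing (no llama.cpp dependencies; the example value is hypothetical):

    #include <cassert>
    #include <string>

    int main() {
        std::string grammar = "llg:json:{\"type\":\"object\"}";  // e.g. output of json_schema_to_grammar()
        assert(grammar.compare(0, 4, "llg:") == 0);              // otherwise: plain GBNF, rule "root"
        auto gp       = grammar.find(':', 4);                    // separator between kind and payload
        auto grm_type = grammar.substr(4, gp - 4);               // "json"
        auto grm_data = grammar.substr(gp + 1);                  // the raw schema / grammar text
        assert(grm_type == "json");
        assert(!grm_data.empty());
        return 0;
    }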
