
Commit 1f65f7a

first draft of llguidance sampler

1 parent 5c333e0

5 files changed, 167 insertions(+), 0 deletions(-)

Makefile

Lines changed: 7 additions & 0 deletions
@@ -73,6 +73,8 @@ LEGACY_TARGETS_CLEAN = main quantize quantize-stats perplexity imatrix embedding
 # We don't want to clutter things too much, so we only build replacements for the most commonly used binaries.
 LEGACY_TARGETS_BUILD = main quantize perplexity embedding server
 
+GGML_LLGUIDANCE := 1
+
 # Deprecation aliases
 ifdef LLAMA_CUBLAS
 $(error LLAMA_CUBLAS is removed. Use GGML_CUDA instead.)
@@ -359,6 +361,11 @@ ifdef LLAMA_SERVER_SSL
 	MK_LDFLAGS += -lssl -lcrypto
 endif
 
+ifdef GGML_LLGUIDANCE
+	MK_CPPFLAGS += -DGGML_LLGUIDANCE -I$(CURDIR)/../guidance-ws/llguidance/parser
+	MK_LDFLAGS  += -L$(CURDIR)/../guidance-ws/target/release -lllguidance_parser
+endif
+
 # warnings
 WARN_FLAGS = \
 	-Wall \
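Two things stand out in this hunk: GGML_LLGUIDANCE is switched on unconditionally in this first draft, and the -I/-L paths hardcode a sibling guidance-ws checkout containing the llguidance sources plus an already-built Rust release library. A plausible preparation step, assuming llguidance is built with cargo from that workspace (the cargo invocation and any directory names beyond those in the flags above are assumptions):

    # assumed sibling layout, mirroring the -I/-L flags in the hunk:
    #   guidance-ws/llguidance/parser   -> llguidance.h
    #   guidance-ws/target/release      -> the built llguidance_parser library
    cd guidance-ws && cargo build --release   # assumed build step for llguidance_parser
    cd ../llama.cpp && make server            # directory name assumed; flag is already := 1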

include/llama.h

Lines changed: 7 additions & 0 deletions
@@ -1125,6 +1125,13 @@ extern "C" {
                            const char * grammar_str,
                            const char * grammar_root);
 
+#ifdef GGML_LLGUIDANCE
+    LLAMA_API struct llama_sampler * llama_sampler_init_llg(
+            const struct llama_model * model,
+                          const char * grammar_type,
+                          const char * grammar_data);
+#endif
+
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
                              int32_t   n_vocab,         // llama_n_vocab()
                          llama_token   special_eos_id,  // llama_token_eos()
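The new entry point mirrors llama_sampler_init_grammar and is meant to be dropped into an existing sampler chain. A minimal usage sketch: the chain calls are llama.cpp's existing public sampler API, while the "json" grammar-type string and the schema payload are assumptions about what llguidance accepts through llg_new_constraint_any:

    // hypothetical caller; model and lctx come from the usual llama.cpp setup
    #ifdef GGML_LLGUIDANCE
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        // veto every token the grammar disallows; type/data are forwarded to llguidance
        llama_sampler_chain_add(chain, llama_sampler_init_llg(model, "json", "{\"type\": \"object\"}"));

        // then sample greedily from whatever survives the mask
        llama_sampler_chain_add(chain, llama_sampler_init_greedy());

        llama_token tok = llama_sampler_sample(chain, lctx, -1);
        llama_sampler_free(chain);
    #endif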

src/llama-sampling.cpp

Lines changed: 142 additions & 0 deletions
@@ -2342,3 +2342,145 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
 
     ctx->t_sample_us = ctx->n_sample = 0;
 }
+
+#ifdef GGML_LLGUIDANCE
+#include "llguidance.h"
+
+struct llama_sampler_llg {
+    const struct llama_vocab * vocab;
+    std::string grammar_kind;
+    std::string grammar_data;
+    LlgConstraint * grammar;
+    LlgMaskResult   llg_res;
+    bool            has_llg_res;
+};
+
+static LlgConstraint * llama_sampler_llg_new(const char * grammar_kind, const char * grammar_data) {
+    LlgConstraintInit cinit;
+    llg_constraint_init_set_defaults(&cinit, nullptr);
+    return llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
+}
+
+static const char * llama_sampler_llg_name(const struct llama_sampler * /*smpl*/) {
+    return "llguidance";
+}
+
+static void llama_sampler_llg_accept_impl(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (ctx->grammar) {
+        LlgCommitResult res;
+        llg_commit_token(ctx->grammar, token, &res);
+        ctx->has_llg_res = false;
+    }
+}
+
+static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (ctx->grammar) {
+        if (!ctx->has_llg_res) {
+            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
+                ctx->has_llg_res = true;
+            } else {
+                LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar));
+            }
+        }
+        if (ctx->has_llg_res) {
+            if (ctx->llg_res.is_stop) {
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (!llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                        cur_p->data[i].logit = -INFINITY;
+                    }
+                }
+            } else {
+                const uint32_t * mask = ctx->llg_res.sample_mask;
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    auto token = cur_p->data[i].id;
+                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
+                        cur_p->data[i].logit = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void llama_sampler_llg_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (!ctx->grammar) {
+        return;
+    }
+
+    auto * grammar_new = llama_sampler_llg_new(ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
+    llg_free_constraint(ctx->grammar);
+    ctx->grammar = grammar_new;
+    ctx->has_llg_res = false;
+}
+
+static struct llama_sampler * llama_sampler_llg_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
+
+    auto * result = llama_sampler_init_llg_impl(*ctx->vocab, nullptr, nullptr);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_llg *) result->ctx;
+
+        if (ctx->grammar) {
+            result_ctx->grammar_kind = ctx->grammar_kind;
+            result_ctx->grammar_data = ctx->grammar_data;
+            result_ctx->grammar = llg_clone_constraint(ctx->grammar);
+        }
+    }
+
+    return result;
+}
+
+static void llama_sampler_llg_free(struct llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_llg *) smpl->ctx;
+
+    if (ctx->grammar) {
+        llg_free_constraint(ctx->grammar);
+    }
+
+    delete ctx;
+}
+
+static struct llama_sampler_i llama_sampler_llg_i = {
+    /* .name   = */ llama_sampler_llg_name,
+    /* .accept = */ llama_sampler_llg_accept_impl,
+    /* .apply  = */ llama_sampler_llg_apply,
+    /* .reset  = */ llama_sampler_llg_reset,
+    /* .clone  = */ llama_sampler_llg_clone,
+    /* .free   = */ llama_sampler_llg_free,
+};
+
+struct llama_sampler * llama_sampler_init_llg_impl(const struct llama_vocab & vocab, const char * grammar_kind, const char * grammar_data) {
+    auto * ctx = new llama_sampler_llg;
+
+    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_kind = */ grammar_kind,
+            /* .grammar_data = */ grammar_data,
+            /* .grammar      = */ llama_sampler_llg_new(grammar_kind, grammar_data),
+            /* .llg_res      = */ {},
+            /* .has_llg_res  = */ false,
+        };
+    } else {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_kind = */ {},
+            /* .grammar_data = */ {},
+            /* .grammar      = */ nullptr,
+            /* .llg_res      = */ {},
+            /* .has_llg_res  = */ false,
+        };
+    }
+
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_llg_i,
+        /* .ctx   = */ ctx,
+    };
+}
+
+#endif
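The core of llama_sampler_llg_apply is the packed bitmask test: llg_compute_mask hands back one bit per vocabulary token, 32 tokens to a uint32_t, and every candidate whose bit is clear has its logit forced to -INFINITY before sampling. A self-contained illustration of the same word/bit indexing (the mask contents here are invented):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // pretend 64-token vocabulary => two 32-bit words; only tokens 3 and 40 allowed
        uint32_t mask[2] = { 1u << 3, 1u << (40 % 32) };

        for (uint32_t token = 0; token < 64; ++token) {
            // identical test to llama_sampler_llg_apply: word = token / 32, bit = token % 32
            if ((mask[token / 32] & (1u << (token % 32))) != 0) {
                printf("token %u survives the mask\n", token);   // prints 3 and 40
            }
        }
        return 0;
    }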

src/llama-sampling.h

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,11 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
         const char * grammar_str,
         const char * grammar_root);
+struct llama_sampler * llama_sampler_init_llg_impl(
+        const struct llama_vocab & vocab,
+        const char * grammar_type,
+        const char * grammar_data);
+
 
 struct llama_sampler * llama_sampler_init_infill_impl(
         const struct llama_vocab & vocab);

src/llama.cpp

Lines changed: 6 additions & 0 deletions
@@ -21866,6 +21866,12 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+#ifdef GGML_LLGUIDANCE
+struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model, const char * grammar_type, const char * grammar_data) {
+    return llama_sampler_init_llg_impl(model->vocab, grammar_type, grammar_data);
+}
+#endif
+
 struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
     return llama_sampler_init_infill_impl(model->vocab);
 }
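Taken together, the wrapper maps the llguidance C lifecycle onto llama.cpp's sampler interface: construct a constraint once, alternate llg_compute_mask (in .apply, before each step) with llg_commit_token (in .accept, after each sampled token), and free the constraint in .free. A standalone sketch of that loop, using only llg_* calls that appear in this diff; pick_masked_token is a hypothetical stand-in for the rest of the sampling chain:

    #include <cstdint>
    #include <cstdio>
    #include "llguidance.h"

    // hypothetical: choose some token whose bit is set in the mask
    extern int32_t pick_masked_token(const uint32_t * sample_mask);

    void generate_with_constraint(const char * kind, const char * data) {
        LlgConstraintInit cinit;
        llg_constraint_init_set_defaults(&cinit, nullptr);   // as in llama_sampler_llg_new
        LlgConstraint * c = llg_new_constraint_any(&cinit, kind, data);

        for (;;) {
            LlgMaskResult mres;
            if (llg_compute_mask(c, &mres) != 0) {           // non-zero == error, as in the sampler
                fprintf(stderr, "llg error: %s\n", llg_get_error(c));
                break;
            }
            if (mres.is_stop) {
                break;                                       // grammar finished; only EOG remains legal
            }
            int32_t tok = pick_masked_token(mres.sample_mask);
            LlgCommitResult cres;
            llg_commit_token(c, tok, &cres);                 // advance the constraint state
        }

        llg_free_constraint(c);
    }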
