@@ -2342,3 +2342,145 @@ void llama_perf_sampler_reset(struct llama_sampler * chain) {
23422342
23432343 ctx->t_sample_us = ctx->n_sample = 0 ;
23442344}
2345+
2346+ #ifdef GGML_LLGUIDANCE
2347+ #include " llguidance.h"
2348+
2349+ struct llama_sampler_llg {
2350+ const struct llama_vocab * vocab;
2351+ std::string grammar_kind;
2352+ std::string grammar_data;
2353+ LlgConstraint *grammar;
2354+ LlgMaskResult llg_res;
2355+ bool has_llg_res;
2356+ };
2357+
2358+ static LlgConstraint *llama_sampler_llg_new (const char * grammar_kind, const char * grammar_data) {
2359+ LlgConstraintInit cinit;
2360+ llg_constraint_init_set_defaults (&cinit, nullptr );
2361+ return llg_new_constraint_any (&cinit, grammar_kind, grammar_data);
2362+ }
2363+
2364+ static const char * llama_sampler_llg_name (const struct llama_sampler * /* smpl*/ ) {
2365+ return " llguidance" ;
2366+ }
2367+
2368+ static void llama_sampler_llg_accept_impl (struct llama_sampler * smpl, llama_token token) {
2369+ auto * ctx = (llama_sampler_llg *) smpl->ctx ;
2370+ if (ctx->grammar ) {
2371+ LlgCommitResult res;
2372+ llg_commit_token (ctx->grammar , token, &res);
2373+ ctx->has_llg_res = false ;
2374+ }
2375+ }
2376+
2377+ static void llama_sampler_llg_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2378+ auto * ctx = (llama_sampler_llg *) smpl->ctx ;
2379+ if (ctx->grammar ) {
2380+ if (!ctx->has_llg_res ) {
2381+ if (llg_compute_mask (ctx->grammar , &ctx->llg_res ) == 0 ) {
2382+ ctx->has_llg_res = true ;
2383+ } else {
2384+ LLAMA_LOG_ERROR (" llg error: %s\n " , llg_get_error (ctx->grammar ));
2385+ }
2386+ }
2387+ if (ctx->has_llg_res ) {
2388+ if (ctx->llg_res .is_stop ) {
2389+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
2390+ if (!llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
2391+ cur_p->data [i].logit = -INFINITY;
2392+ }
2393+ }
2394+ } else {
2395+ const uint32_t *mask = ctx->llg_res .sample_mask ;
2396+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
2397+ auto token = cur_p->data [i].id ;
2398+ if ((mask[token / 32 ] & (1 << (token % 32 ))) == 0 ) {
2399+ cur_p->data [i].logit = -INFINITY;
2400+ }
2401+ }
2402+ }
2403+ }
2404+ }
2405+ }
2406+
2407+ static void llama_sampler_llg_reset (struct llama_sampler * smpl) {
2408+ auto * ctx = (llama_sampler_llg *) smpl->ctx ;
2409+ if (!ctx->grammar ) {
2410+ return ;
2411+ }
2412+
2413+ auto * grammar_new = llama_sampler_llg_new (ctx->grammar_kind .c_str (), ctx->grammar_data .c_str ());
2414+ llg_free_constraint (ctx->grammar );
2415+ ctx->grammar = grammar_new;
2416+ ctx->has_llg_res = false ;
2417+ }
2418+
2419+ static struct llama_sampler * llama_sampler_llg_clone (const struct llama_sampler * smpl) {
2420+ const auto * ctx = (const llama_sampler_llg *) smpl->ctx ;
2421+
2422+ auto * result = llama_sampler_init_llg_impl (*ctx->vocab , nullptr , nullptr );
2423+
2424+ // copy the state
2425+ {
2426+ auto * result_ctx = (llama_sampler_llg *) result->ctx ;
2427+
2428+ if (ctx->grammar ) {
2429+ result_ctx->grammar_kind = ctx->grammar_kind ;
2430+ result_ctx->grammar_data = ctx->grammar_data ;
2431+ result_ctx->grammar = llg_clone_constraint (ctx->grammar );
2432+ }
2433+ }
2434+
2435+ return result;
2436+ }
2437+
2438+ static void llama_sampler_llg_free (struct llama_sampler * smpl) {
2439+ const auto * ctx = (llama_sampler_llg *) smpl->ctx ;
2440+
2441+ if (ctx->grammar ) {
2442+ llg_free_constraint (ctx->grammar );
2443+ }
2444+
2445+ delete ctx;
2446+ }
2447+
2448+ static struct llama_sampler_i llama_sampler_llg_i = {
2449+ /* .name = */ llama_sampler_llg_name,
2450+ /* .accept = */ llama_sampler_llg_accept_impl,
2451+ /* .apply = */ llama_sampler_llg_apply,
2452+ /* .reset = */ llama_sampler_llg_reset,
2453+ /* .clone = */ llama_sampler_llg_clone,
2454+ /* .free = */ llama_sampler_llg_free,
2455+ };
2456+
2457+ struct llama_sampler * llama_sampler_init_llg_impl (const struct llama_vocab & vocab, const char * grammar_kind, const char * grammar_data) {
2458+ auto * ctx = new llama_sampler_llg;
2459+
2460+ if (grammar_kind != nullptr && grammar_kind[0 ] != ' \0 ' ) {
2461+ *ctx = {
2462+ /* .vocab = */ &vocab,
2463+ /* .grammar_kind = */ grammar_kind,
2464+ /* .grammar_data = */ grammar_data,
2465+ /* .grammar = */ llama_sampler_llg_new (grammar_kind, grammar_data),
2466+ /* .llg_res = */ {},
2467+ /* .has_llg_res = */ false ,
2468+ };
2469+ } else {
2470+ *ctx = {
2471+ /* .vocab = */ &vocab,
2472+ /* .grammar_kind = */ {},
2473+ /* .grammar_data = */ {},
2474+ /* .grammar = */ nullptr ,
2475+ /* .llg_res = */ {},
2476+ /* .has_llg_res = */ false ,
2477+ };
2478+ }
2479+
2480+ return new llama_sampler {
2481+ /* .iface = */ &llama_sampler_llg_i,
2482+ /* .ctx = */ ctx,
2483+ };
2484+ }
2485+
2486+ #endif
0 commit comments