Commit f19655c

update for new APIs
1 parent 76290d9 commit f19655c

4 files changed: +34, -16 lines

common/json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion

@@ -991,7 +991,7 @@ class SchemaConverter {
 };
 
 std::string json_schema_to_grammar(const json & schema) {
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
     return "llg:json:" + schema.dump();
 #else
     return build_grammar([&](const llama_grammar_builder & callbacks) {
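
With LLAMA_USE_LLGUIDANCE defined, json_schema_to_grammar() no longer lowers the schema to GBNF; it serializes it behind an "llg:json:" prefix so the sampler layer can hand the schema to llguidance verbatim (see common/sampling.cpp below). A minimal sketch of that serialization, assuming nlohmann::ordered_json as used elsewhere in common/; serialize_for_llg is a hypothetical stand-in for the patched function:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// hypothetical stand-in for the patched json_schema_to_grammar()
static std::string serialize_for_llg(const json & schema) {
    return "llg:json:" + schema.dump();  // tag + schema, passed through verbatim
}

int main() {
    json schema = {
        { "type", "object" },
        { "properties", { { "name", { { "type", "string" } } } } },
    };
    // prints: llg:json:{"type":"object","properties":{"name":{"type":"string"}}}
    std::cout << serialize_for_llg(schema) << "\n";
}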

common/llguidance.cpp

Lines changed: 26 additions & 13 deletions

@@ -1,8 +1,15 @@
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
+
+#include "common.h"
+#include "sampling.h"
+#include "log.h"
+#include "llama.h"
+
 #include "llguidance.h"
 
 struct llama_sampler_llg {
     const struct llama_model * model;
+    const struct llama_vocab * vocab;
     std::string grammar_kind;
     std::string grammar_data;
     LlgTokenizer *tokenizer;
@@ -17,7 +24,7 @@ static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
     llg_constraint_init_set_defaults(&cinit, tokenizer);
     auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
     if (llg_get_error(c)) {
-        LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c));
+        LOG_ERR("llg error: %s\n", llg_get_error(c));
         llg_free_constraint(c);
         return nullptr;
     }
@@ -44,15 +51,15 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
         if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
             ctx->has_llg_res = true;
         } else {
-            LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar));
+            LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
             llg_free_constraint(ctx->grammar);
             ctx->grammar = nullptr;
         }
     }
     if (ctx->has_llg_res) {
         if (ctx->llg_res.is_stop) {
             for (size_t i = 0; i < cur_p->size; ++i) {
-                if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) {
+                if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                     cur_p->data[i].logit = -INFINITY;
                 }
             }
@@ -128,8 +135,8 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
                                             uint32_t *output_tokens,
                                             size_t output_tokens_len)
 {
-    const struct llama_model *model = (const struct llama_model *)user_data;
-    int r = llama_tokenize(model, (const char *) bytes, bytes_len,
+    const struct llama_vocab *vocab = (const struct llama_vocab *)user_data;
+    int r = llama_tokenize(vocab, (const char *) bytes, bytes_len,
                            (int32_t*)output_tokens, output_tokens_len, false, true);
     if (r < 0)
         return -r;
@@ -145,11 +152,13 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
         return llg_clone_tokenizer(tokenizer_cache);
     }
 
-    auto tok_eos = llama_token_eot(model);
+    const struct llama_vocab *vocab = llama_model_get_vocab(model);
+
+    auto tok_eos = llama_vocab_eot(vocab);
     if (tok_eos == LLAMA_TOKEN_NULL)
-        tok_eos = llama_token_eos(model);
+        tok_eos = llama_vocab_eos(vocab);
 
-    size_t vocab_size = llama_n_vocab(model);
+    size_t vocab_size = llama_vocab_n_tokens(vocab);
 
     auto token_lens = new uint32_t[vocab_size];
     // we typically have ~7 bytes per token; let's go on the safe side here
@@ -165,12 +174,12 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
 
         llama_token token = i;
         auto dp = (char *) token_bytes + offset;
-        auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false);
+        auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
         if (size < 0) {
             GGML_ABORT("llama_detokenize failed\n");
         }
         if (size == 0) {
-            size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true);
+            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
             if (size < 0) {
                 GGML_ABORT("llama_detokenize failed\n");
             }
@@ -194,7 +203,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
         /* .tokenize_assumes_string = */ false,
         /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
         /* .use_approximate_greedy_tokenize_fn = */ false,
-        /* .tokenize_user_data = */ model,
+        /* .tokenize_user_data = */ vocab,
     };
 
     char error_buffer[1024];
@@ -204,7 +213,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
     delete[] token_lens;
 
     if (tokenizer == nullptr) {
-        LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer);
+        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
         return tokenizer;
     }
 
@@ -221,10 +230,13 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
                                               const char * grammar_kind, const char * grammar_data) {
     auto * ctx = new llama_sampler_llg;
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
         auto tokenizer = llama_sampler_llg_new_tokenizer(model);
         *ctx = {
             /* .model = */ model,
+            /* .vocab = */ vocab,
             /* .grammar_kind = */ grammar_kind,
             /* .grammar_data = */ grammar_data,
             /* .tokenizer = */ tokenizer,
@@ -235,6 +247,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
     } else {
         *ctx = {
             /* .model = */ model,
+            /* .vocab = */ vocab,
             /* .grammar_kind = */ {},
             /* .grammar_data = */ {},
             /* .tokenizer = */ nullptr,
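
Most of this file tracks llama.cpp's model-to-vocab API split: token-level queries that used to take a llama_model * now take a llama_vocab * obtained via llama_model_get_vocab(). A minimal sketch of the mapping, assuming only llama.h; summarize_vocab is a hypothetical helper, not part of this patch:

#include <cstdio>
#include "llama.h"

// hypothetical helper showing the new vocab-level entry points side by side
// with the old model-level calls they replace
static void summarize_vocab(const struct llama_model * model) {
    const struct llama_vocab * vocab = llama_model_get_vocab(model);

    llama_token tok_eos = llama_vocab_eot(vocab);       // was llama_token_eot(model)
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);               // was llama_token_eos(model)
    }

    printf("n_tokens = %d, eos-like = %d, is_eog = %d\n",
           llama_vocab_n_tokens(vocab),                 // was llama_n_vocab(model)
           tok_eos,
           (int) llama_vocab_is_eog(vocab, tok_eos));   // was llama_token_is_eog(model, ...)
}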

common/sampling.cpp

Lines changed: 2 additions & 2 deletions

@@ -153,7 +153,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     struct llama_sampler * grmr;
     if (params.grammar.compare(0, 4, "llg:") == 0) {
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
         auto gp = params.grammar.find(':', 4);
         if (gp == std::string::npos) {
             GGML_ABORT("invalid serialized grammar");
@@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         auto grm_data = params.grammar.c_str() + gp + 1;
         grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data);
 #else
-        GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled");
+        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif
     } else {
         grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
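
The serialized grammar format accepted here is "llg:<kind>:<data>": <kind> names a grammar flavor llguidance understands (e.g. "json" for a JSON schema) and <data> is forwarded untouched. A sketch of that parsing under the same assumptions; split_llg_grammar is a made-up helper mirroring the find(':', 4) logic above:

#include <stdexcept>
#include <string>
#include <utility>

// made-up helper mirroring the parsing in common_sampler_init()
static std::pair<std::string, std::string> split_llg_grammar(const std::string & grammar) {
    if (grammar.compare(0, 4, "llg:") != 0) {
        throw std::invalid_argument("not an llg grammar");
    }
    const auto gp = grammar.find(':', 4);
    if (gp == std::string::npos) {
        throw std::invalid_argument("invalid serialized grammar");
    }
    return { grammar.substr(4, gp - 4),   // kind, e.g. "json"
             grammar.substr(gp + 1) };    // data, handed to llguidance verbatim
}

// e.g. split_llg_grammar("llg:json:{}") yields {"json", "{}"}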

common/sampling.h

Lines changed: 5 additions & 0 deletions

@@ -102,3 +102,8 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
+
+#ifdef LLAMA_USE_LLGUIDANCE
+struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
+        const char * grammar_kind, const char * grammar_data);
+#endif
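
A hedged usage sketch for the newly exported declaration, assuming a loaded model and a build configured with -DLLAMA_LLGUIDANCE=ON (which defines LLAMA_USE_LLGUIDANCE); make_json_sampler is illustrative, not part of the patch:

#include "llama.h"
#include "sampling.h"

#ifdef LLAMA_USE_LLGUIDANCE
// illustrative only: constrain sampling to any JSON object via llguidance
static struct llama_sampler * make_json_sampler(const struct llama_model * model) {
    return llama_sampler_init_llg(model, "json", "{\"type\": \"object\"}");
}
#endif

The returned sampler is released with llama_sampler_free(), like any other llama_sampler.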
