Skip to content

Commit 4cc0a8a

Browse files
committed
fixed grammar sampler reinit
1 parent 2572396 commit 4cc0a8a

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

common/sampling.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,14 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
266266
llama_sampler_reset(gsmpl->chain);
267267
}
268268

269-
void common_sampler_reset_grammar(struct common_sampler * gsmpl) {
269+
void common_sampler_reinit_grammar(struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
270270
llama_sampler_reset(gsmpl->grmr);
271271

272-
llama_sampler_reset(gsmpl->chain);
272+
gsmpl->grmr = llama_sampler_init_grammar(model, grammar, "root");
273+
}
274+
275+
void common_sampler_reset_grammar(struct common_sampler * gsmpl) {
276+
llama_sampler_reset(gsmpl->grmr);
273277
}
274278

275279
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {

common/sampling.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void common_sampler_free(struct common_sampler * gsmpl);
4343
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
4444
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
4545
void common_sampler_reset (struct common_sampler * gsmpl);
46+
void common_sampler_reinit_grammar(struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar);
4647
void common_sampler_reset_grammar(struct common_sampler * gsmpl);
4748
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
4849

0 commit comments

Comments
 (0)