@@ -60,6 +60,17 @@ struct ring_buffer {
6060 return value;
6161 }
6262
63+ T pop_back () {
64+ if (sz == 0 ) {
65+ throw std::runtime_error (" ring buffer is empty" );
66+ }
67+ // Move pos backwards, wrapping around if necessary
68+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
69+ T value = data[pos];
70+ sz--;
71+ return value;
72+ }
73+
6374 const T & rat (size_t i) const {
6475 if (i >= sz) {
6576 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -163,15 +174,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
163174
164175 llama_sampler_chain_add (result->chain ,
165176 llama_sampler_init_penalties (
166- llama_n_vocab (model),
167- llama_token_eos (model),
168- llama_token_nl (model),
169- params.penalty_last_n ,
170- params.penalty_repeat ,
171- params.penalty_freq ,
172- params.penalty_present ,
173- params.penalize_nl ,
174- params.ignore_eos ));
177+ llama_n_vocab (model),
178+ llama_token_eos (model),
179+ llama_token_nl (model),
180+ params.penalty_last_n ,
181+ params.penalty_repeat ,
182+ params.penalty_freq ,
183+ params.penalty_present ,
184+ params.penalize_nl ,
185+ params.ignore_eos ));
175186
176187 if (params.mirostat == 0 ) {
177188 for (const auto & cnstr : params.samplers ) {
@@ -252,6 +263,16 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
252263 llama_sampler_reset (gsmpl->chain );
253264}
254265
266+ void common_sampler_reinit_grammar (struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
267+ llama_sampler_reset (gsmpl->grmr );
268+
269+ gsmpl->grmr = llama_sampler_init_grammar (model, grammar, " root" );
270+ }
271+
272+ void common_sampler_reset_grammar (struct common_sampler * gsmpl) {
273+ llama_sampler_reset (gsmpl->grmr );
274+ }
275+
255276struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
256277 return new common_sampler {
257278 /* .params = */ gsmpl->params ,
@@ -366,6 +387,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
366387 return result;
367388}
368389
390+ const std::vector<llama_token> common_sampler_prev (common_sampler * gsmpl) {
391+ return gsmpl->prev .to_vector ();
392+ }
393+
394+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
395+ if (rollback_num > gsmpl->prev .size ()) {
396+ rollback_num = gsmpl->prev .size ();
397+ }
398+
399+ // continuously pop the last token
400+ for (int i = 0 ; i < rollback_num; i++) {
401+ gsmpl->prev .pop_back ();
402+ }
403+ }
404+
369405char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
370406 switch (cnstr) {
371407 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
0 commit comments