@@ -61,6 +61,17 @@ struct ring_buffer {
6161 return value;
6262 }
6363
64+ T pop_back () {
65+ if (sz == 0 ) {
66+ throw std::runtime_error (" ring buffer is empty" );
67+ }
68+ // Move pos backwards, wrapping around if necessary
69+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
70+ T value = data[pos];
71+ sz--;
72+ return value;
73+ }
74+
6475 const T & rat (size_t i) const {
6576 if (i >= sz) {
6677 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -313,6 +324,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
313324 llama_sampler_reset (gsmpl->chain );
314325}
315326
327+ void common_sampler_reinit_grammar (struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
328+ llama_sampler_reset (gsmpl->grmr );
329+
330+ gsmpl->grmr = llama_sampler_init_grammar (llama_model_get_vocab (model), grammar, " root" );
331+ }
332+
316333struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
317334 return new common_sampler {
318335 /* .params = */ gsmpl->params ,
@@ -466,6 +483,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
466483 return result;
467484}
468485
486+ const std::vector<llama_token> common_sampler_prev (common_sampler * gsmpl) {
487+ return gsmpl->prev .to_vector ();
488+ }
489+
490+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
491+ if (rollback_num > gsmpl->prev .size ()) {
492+ rollback_num = gsmpl->prev .size ();
493+ }
494+
495+ // continuously pop the last token
496+ for (int i = 0 ; i < rollback_num; i++) {
497+ gsmpl->prev .pop_back ();
498+ }
499+ }
500+
469501char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
470502 switch (cnstr) {
471503 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
0 commit comments