@@ -61,6 +61,17 @@ struct ring_buffer {
6161 return value;
6262 }
6363
64+ T pop_back () {
65+ if (sz == 0 ) {
66+ throw std::runtime_error (" ring buffer is empty" );
67+ }
68+ // Move pos backwards, wrapping around if necessary
69+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
70+ T value = data[pos];
71+ sz--;
72+ return value;
73+ }
74+
6475 const T & rat (size_t i) const {
6576 if (i >= sz) {
6677 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -316,6 +327,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
316327 llama_sampler_reset (gsmpl->chain );
317328}
318329
330+ void common_sampler_reinit_grammar (struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
331+ llama_sampler_reset (gsmpl->grmr );
332+
333+ gsmpl->grmr = llama_sampler_init_grammar (llama_model_get_vocab (model), grammar, " root" );
334+ }
335+
319336struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
320337 return new common_sampler {
321338 /* .params = */ gsmpl->params ,
@@ -469,6 +486,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
469486 return result;
470487}
471488
489+ const std::vector<llama_token> common_sampler_prev (common_sampler * gsmpl) {
490+ return gsmpl->prev .to_vector ();
491+ }
492+
493+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
494+ if (rollback_num > gsmpl->prev .size ()) {
495+ rollback_num = gsmpl->prev .size ();
496+ }
497+
498+ // continuously pop the last token
499+ for (int i = 0 ; i < rollback_num; i++) {
500+ gsmpl->prev .pop_back ();
501+ }
502+ }
503+
472504char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
473505 switch (cnstr) {
474506 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
0 commit comments