@@ -60,6 +60,17 @@ struct ring_buffer {
6060 return value;
6161 }
6262
63+ T pop_back () {
64+ if (sz == 0 ) {
65+ throw std::runtime_error (" ring buffer is empty" );
66+ }
67+ // Move pos backwards, wrapping around if necessary
68+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
69+ T value = data[pos];
70+ sz--;
71+ return value;
72+ }
73+
6374 const T & rat (size_t i) const {
6475 if (i >= sz) {
6576 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -248,6 +259,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
248259 llama_sampler_reset (gsmpl->chain );
249260}
250261
262+ void common_sampler_reinit_grammar (struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
263+ llama_sampler_reset (gsmpl->grmr );
264+
265+ gsmpl->grmr = llama_sampler_init_grammar (model, grammar, " root" );
266+ }
267+
251268struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
252269 return new common_sampler {
253270 /* .params = */ gsmpl->params ,
@@ -401,6 +418,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
401418 return result;
402419}
403420
421+ const std::vector<llama_token> common_sampler_prev (common_sampler * gsmpl) {
422+ return gsmpl->prev .to_vector ();
423+ }
424+
425+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
426+ if (rollback_num > gsmpl->prev .size ()) {
427+ rollback_num = gsmpl->prev .size ();
428+ }
429+
430+ // continuously pop the last token
431+ for (int i = 0 ; i < rollback_num; i++) {
432+ gsmpl->prev .pop_back ();
433+ }
434+ }
435+
404436char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
405437 switch (cnstr) {
406438 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
0 commit comments