@@ -60,6 +60,17 @@ struct ring_buffer {
6060 return value;
6161 }
6262
63+ T pop_back () {
64+ if (sz == 0 ) {
65+ throw std::runtime_error (" ring buffer is empty" );
66+ }
67+ // Move pos backwards, wrapping around if necessary
68+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
69+ T value = data[pos];
70+ sz--;
71+ return value;
72+ }
73+
6374 const T & rat (size_t i) const {
6475 if (i >= sz) {
6576 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -275,6 +286,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
275286 llama_sampler_reset (gsmpl->chain );
276287}
277288
289+ void common_sampler_reinit_grammar (struct common_sampler * gsmpl, const struct llama_model * model, const char * grammar) {
290+ llama_sampler_reset (gsmpl->grmr );
291+
292+ gsmpl->grmr = llama_sampler_init_grammar (llama_model_get_vocab (model), grammar, " root" );
293+ }
294+
278295struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
279296 return new common_sampler {
280297 /* .params = */ gsmpl->params ,
@@ -428,6 +445,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
428445 return result;
429446}
430447
448+ const std::vector<llama_token> common_sampler_prev (common_sampler * gsmpl) {
449+ return gsmpl->prev .to_vector ();
450+ }
451+
452+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
453+ if (rollback_num > gsmpl->prev .size ()) {
454+ rollback_num = gsmpl->prev .size ();
455+ }
456+
457+ // continuously pop the last token
458+ for (int i = 0 ; i < rollback_num; i++) {
459+ gsmpl->prev .pop_back ();
460+ }
461+ }
462+
431463char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
432464 switch (cnstr) {
433465 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
0 commit comments