@@ -60,6 +60,17 @@ struct ring_buffer {
6060 return value;
6161 }
6262
63+ T pop_back () {
64+ if (sz == 0 ) {
65+ throw std::runtime_error (" ring buffer is empty" );
66+ }
67+ // Move pos backwards, wrapping around if necessary
68+ pos = (pos == 0 ) ? capacity - 1 : pos - 1 ;
69+ T value = data[pos];
70+ sz--;
71+ return value;
72+ }
73+
6374 const T & rat (size_t i) const {
6475 if (i >= sz) {
6576 throw std::runtime_error (" ring buffer: index out of bounds" );
@@ -163,15 +174,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
163174
164175 llama_sampler_chain_add (result->chain ,
165176 llama_sampler_init_penalties (
166- llama_n_vocab (model),
167- llama_token_eos (model),
168- llama_token_nl (model),
169- params.penalty_last_n ,
170- params.penalty_repeat ,
171- params.penalty_freq ,
172- params.penalty_present ,
173- params.penalize_nl ,
174- params.ignore_eos ));
177+ llama_n_vocab (model),
178+ llama_token_eos (model),
179+ llama_token_nl (model),
180+ params.penalty_last_n ,
181+ params.penalty_repeat ,
182+ params.penalty_freq ,
183+ params.penalty_present ,
184+ params.penalize_nl ,
185+ params.ignore_eos ));
175186
176187 if (params.mirostat == 0 ) {
177188 for (const auto & cnstr : params.samplers ) {
@@ -255,6 +266,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
255266 llama_sampler_reset (gsmpl->chain );
256267}
257268
269+ void common_sampler_reset_grammar (struct common_sampler * gsmpl) {
270+ llama_sampler_reset (gsmpl->grmr );
271+
272+ llama_sampler_reset (gsmpl->chain );
273+ }
274+
258275struct common_sampler * common_sampler_clone (common_sampler * gsmpl) {
259276 return new common_sampler {
260277 /* .params = */ gsmpl->params ,
@@ -369,6 +386,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
369386 return result;
370387}
371388
389+ const std::vector<llama_token>& common_sampler_prev (common_sampler * gsmpl) {
390+ return gsmpl->prev .data ;
391+ }
392+
393+ void common_sampler_rollback (common_sampler * gsmpl, int rollback_num) {
394+ if (rollback_num > gsmpl->prev .size ()) {
395+ rollback_num = gsmpl->prev .size ();
396+ }
397+
398+ // continuously pop the last token
399+ for (int i = 0 ; i < rollback_num; i++) {
400+ gsmpl->prev .pop_back ();
401+ }
402+ }
403+
372404char common_sampler_type_to_chr (enum common_sampler_type cnstr) {
373405 switch (cnstr) {
374406 case COMMON_SAMPLER_TYPE_DRY: return ' d' ;
@@ -472,4 +504,4 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
472504 }
473505
474506 return samplers;
475- }
507+ }
0 commit comments