Skip to content

Commit 2572396

Browse files
committed
added layla used functions
1 parent 35e499a commit 2572396

File tree

4 files changed

+1503
-1623
lines changed

4 files changed

+1503
-1623
lines changed

common/sampling.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ struct ring_buffer {
6060
return value;
6161
}
6262

63+
T pop_back() {
64+
if (sz == 0) {
65+
throw std::runtime_error("ring buffer is empty");
66+
}
67+
// Move pos backwards, wrapping around if necessary
68+
pos = (pos == 0) ? capacity - 1 : pos - 1;
69+
T value = data[pos];
70+
sz--;
71+
return value;
72+
}
73+
6374
const T & rat(size_t i) const {
6475
if (i >= sz) {
6576
throw std::runtime_error("ring buffer: index out of bounds");
@@ -163,15 +174,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
163174

164175
llama_sampler_chain_add(result->chain,
165176
llama_sampler_init_penalties(
166-
llama_n_vocab (model),
167-
llama_token_eos(model),
168-
llama_token_nl (model),
169-
params.penalty_last_n,
170-
params.penalty_repeat,
171-
params.penalty_freq,
172-
params.penalty_present,
173-
params.penalize_nl,
174-
params.ignore_eos));
177+
llama_n_vocab (model),
178+
llama_token_eos(model),
179+
llama_token_nl (model),
180+
params.penalty_last_n,
181+
params.penalty_repeat,
182+
params.penalty_freq,
183+
params.penalty_present,
184+
params.penalize_nl,
185+
params.ignore_eos));
175186

176187
if (params.mirostat == 0) {
177188
for (const auto & cnstr : params.samplers) {
@@ -255,6 +266,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
255266
llama_sampler_reset(gsmpl->chain);
256267
}
257268

269+
void common_sampler_reset_grammar(struct common_sampler * gsmpl) {
270+
llama_sampler_reset(gsmpl->grmr);
271+
272+
llama_sampler_reset(gsmpl->chain);
273+
}
274+
258275
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
259276
return new common_sampler {
260277
/* .params = */ gsmpl->params,
@@ -369,6 +386,21 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
369386
return result;
370387
}
371388

389+
const std::vector<llama_token>& common_sampler_prev(common_sampler * gsmpl) {
390+
return gsmpl->prev.data;
391+
}
392+
393+
void common_sampler_rollback(common_sampler * gsmpl, int rollback_num) {
394+
if(rollback_num > gsmpl->prev.size()) {
395+
rollback_num = gsmpl->prev.size();
396+
}
397+
398+
// continuously pop the last token
399+
for(int i = 0; i < rollback_num; i++) {
400+
gsmpl->prev.pop_back();
401+
}
402+
}
403+
372404
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
373405
switch (cnstr) {
374406
case COMMON_SAMPLER_TYPE_DRY: return 'd';
@@ -472,4 +504,4 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
472504
}
473505

474506
return samplers;
475-
}
507+
}

common/sampling.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void common_sampler_free(struct common_sampler * gsmpl);
4343
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
4444
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
4545
void common_sampler_reset (struct common_sampler * gsmpl);
46+
void common_sampler_reset_grammar(struct common_sampler * gsmpl);
4647
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
4748

4849
// arguments can be nullptr to skip printing
@@ -75,9 +76,11 @@ std::string common_sampler_print(const struct common_sampler * gsmpl);
7576

7677
// get a string representation of the last accepted tokens
7778
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
79+
const std::vector<llama_token>& common_sampler_prev(common_sampler * gsmpl);
80+
void common_sampler_rollback(common_sampler * gsmpl, int rollback_num);
7881

7982
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
8083
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
8184

8285
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
83-
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
86+
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);

0 commit comments

Comments
 (0)