Skip to content

Commit 1f2add8

Browse files
committed
XTC: Added xtc_probability_once parameter
* Allows choosing between one quick random choice per call (as in the original) and per-token random choices.
1 parent 363bcbc commit 1f2add8

File tree

8 files changed

+56
-19
lines changed

8 files changed

+56
-19
lines changed

base/common.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,16 @@ int get_math_cpu_count() {
180180
return get_num_physical_cores();
181181
}
182182

183+
// Replaces every non-overlapping occurrence of `search` in `s` with `replace`,
// scanning left to right. Text inserted by a replacement is never re-scanned,
// so a `replace` that itself contains `search` cannot cause an endless loop.
//
// The result is assembled in a separate buffer in a single pass instead of
// editing `s` in place: in-place replacement shifts the remaining tail on every
// hit, which becomes quadratic for inputs with many matches.
//
// s       - string to modify; receives the result
// search  - substring to look for; an empty value leaves `s` untouched
// replace - text substituted for each occurrence (may be empty)
void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // An empty needle matches everywhere; treating it as a no-op avoids an infinite loop
    }

    size_t hit = s.find(search);
    if (hit == std::string::npos) {
        return; // Nothing to replace, so skip the copy entirely
    }

    std::string rebuilt;
    rebuilt.reserve(s.length());

    size_t copied_upto = 0; // first offset of `s` not yet transferred to `rebuilt`
    do {
        rebuilt.append(s, copied_upto, hit - copied_upto); // untouched text before the match
        rebuilt.append(replace);
        copied_upto = hit + search.length();
        hit = s.find(search, copied_upto);
    } while (hit != std::string::npos);

    rebuilt.append(s, copied_upto, std::string::npos); // trailing text after the last match
    s.swap(rebuilt);
}
183193

184194
void process_escapes(std::string& input) {
185195
std::size_t input_len = input.length();

base/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ std::string get_system_info(const gpt_params & params);
156156

157157
std::string gpt_random_prompt(std::mt19937 & rng);
158158

159+
// Replaces every occurrence of `search` in `s` with `replace`, modifying `s` in place.
// An empty `search` leaves `s` untouched.
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
160+
159161
void process_escapes(std::string& input);
160162

161163
//

base/llama-addon.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,44 @@
3636
#include <type_traits>
3737
#include <unordered_map>
3838

39-
void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, size_t min_keep) {
39+
void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, float xtc_probability_once, size_t min_keep) {
4040
if (xtc_probability <= 0.0f || xtc_threshold <= 0.0f || candidates->size <= 1) {
4141
return;
4242
}
4343

44+
std::random_device rd;
45+
float chance = (float)(rd()%100)/100;
46+
//printf("\nChance = %f; ", chance);
47+
if (xtc_probability_once && chance > xtc_probability) return;
48+
4449
llama_sample_softmax(nullptr, candidates);
4550

4651
const int64_t t_start_sample_us = ggml_time_us();
4752
size_t removed = 0;
4853
for (size_t i = 0; i < (candidates->size - 1); ++i) {
4954
if (candidates->data[i].p >= xtc_threshold) {
50-
std::random_device rd;
51-
float chance = (float)(rd()%100)/100;
52-
53-
if (chance <= xtc_probability) {
55+
if (xtc_probability_once || chance <= xtc_probability) {
5456
candidates->data[i].logit = -999.0f; // .p will be recalculated in llama_sample_softmax_impl later based on .logit, so we need to change these
5557
++removed;
58+
if (!xtc_probability_once) {
59+
chance = (float)(rd()%100)/100;
60+
printf(" chance = %f; ", chance);
61+
}
5662
}
5763
}
5864
}
65+
66+
//printf("\nPresort (size %zu): %f, %f, %f", candidates->size, candidates->data[0].p, candidates->data[1].p, candidates->data[2].p);
67+
5968
// sorting with new logits
6069
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
6170
return a.logit > b.logit;
6271
});
6372
//resizing now that penalized tokens are at the back
6473
candidates->size = candidates->size - removed;
6574

75+
//printf("\nSort (size %zu): %f, %f, %f\n", candidates->size, candidates->data[0].p, candidates->data[1].p, candidates->data[2].p);
76+
6677
llama_set_time(ctx, t_start_sample_us);
6778
}
6879

base/llama-addon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
llama_token_data_array * candidates,
88
float xtc_probability,
99
float xtc_threshold,
10+
float xtc_probability_once,
1011
size_t min_keep);
1112

1213
/// @details P-Step sampling as described in [THIS PR]

base/sampling.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,20 @@ void sampler_queue(
134134
llama_token_data_array & cur_p,
135135
size_t & min_keep) {
136136

137-
const float temp = params.temp;
138-
const float smoothing_factor = params.smoothing_factor;
139-
const float smoothing_curve = params.smoothing_curve;
140-
const float dynatemp_range = params.dynatemp_range;
141-
const int32_t top_k = params.top_k;
142-
const float top_p = params.top_p;
143-
const float min_p = params.min_p;
144-
const float tfs_z = params.tfs_z;
145-
const float typical_p = params.typical_p;
146-
const float p_step = params.p_step;
147-
const float xtc_probability = params.xtc_probability;
148-
const float xtc_threshold = params.xtc_threshold;
149-
const std::string samplers_sequence = params.samplers_sequence;
137+
const float temp = params.temp;
138+
const float smoothing_factor = params.smoothing_factor;
139+
const float smoothing_curve = params.smoothing_curve;
140+
const float dynatemp_range = params.dynatemp_range;
141+
const int32_t top_k = params.top_k;
142+
const float top_p = params.top_p;
143+
const float min_p = params.min_p;
144+
const float tfs_z = params.tfs_z;
145+
const float typical_p = params.typical_p;
146+
const float p_step = params.p_step;
147+
const float xtc_probability = params.xtc_probability;
148+
const float xtc_threshold = params.xtc_threshold;
149+
const float xtc_probability_once = params.xtc_probability_once;
150+
const std::string samplers_sequence = params.samplers_sequence;
150151

151152
for (auto s : samplers_sequence){
152153
switch (s){
@@ -156,7 +157,7 @@ void sampler_queue(
156157
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
157158
case 'm': llama_sample_min_p_addon (ctx_main, &cur_p, min_p, min_keep); break;
158159
case 's': llama_sample_p_step_addon (ctx_main, &cur_p, p_step, min_keep); break;
159-
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
160+
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, xtc_probability_once, min_keep); break;
160161
case 't': {
161162
if (dynatemp_range>0)
162163
{

base/sampling.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ typedef struct llama_sampling_params {
5959
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
6060
float xtc_probability = 0.5; // probability of removing a top token
6161
float xtc_threshold = 0.1; // minimum token probability for this to run
62+
bool xtc_probability_once = false; // roll the chance once per call (true) or for each token (false)
6263
std::string samplers_sequence = "kfypmts"; // top_k, tail_free, typical_p, top_p, min_p, temp, p_step
6364

6465
std::string grammar; // optional BNF-like grammar to constrain sampling

include/jsonParams.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ static bool checkJNum(nlohmann::json& config, std::string name){
111111
return false;
112112
}
113113

114+
// Returns true only when `config` holds the key `name` and its value is a JSON boolean.
static bool checkJBool(nlohmann::json& config, std::string name){
    // Short-circuit keeps the lookup from running when the key is absent
    return config.contains(name) && config[name].is_boolean();
}
121+
114122
static bool checkJObj(nlohmann::json& config, std::string name){
115123
if(config.contains(name)){
116124
if(config[name].is_object()) {
@@ -511,6 +519,7 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
511519
if (checkJNum(config, "tfs_z")) params.sparams.tfs_z = config["tfs_z"];
512520
if (checkJNum(config, "xtc_probability")) params.sparams.xtc_probability = config["xtc_probability"];
513521
if (checkJNum(config, "xtc_threshold")) params.sparams.xtc_threshold = config["xtc_threshold"];
522+
if (checkJBool(config, "xtc_probability_once")) params.sparams.xtc_probability_once = config["xtc_probability_once"];
514523

515524
//penalties
516525
if (checkJNum(config, "repeat_penalty")) params.sparams.penalty_repeat = config["repeat_penalty"];

thread_chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,7 @@ struct configurableChat{
13131313
} else if (params.sparams.p_step != paramsDefault.sparams.p_step) modelConfig[model]["p_step"] = params.sparams.p_step;
13141314
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) modelConfig[model]["xtc_probability"] = params.sparams.xtc_probability;
13151315
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) modelConfig[model]["xtc_threshold"] = params.sparams.xtc_threshold;
1316+
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) modelConfig[model]["xtc_probability_once"] = params.sparams.xtc_probability_once;
13161317
// penalties
13171318
if (params.sparams.penalty_repeat != paramsDefault.sparams.penalty_repeat) modelConfig[model]["repeat_penalty"] = params.sparams.penalty_repeat;
13181319
if (params.sparams.penalty_threshold != paramsDefault.sparams.penalty_threshold) modelConfig[model]["penalty_threshold"] = params.sparams.penalty_threshold;
@@ -1422,6 +1423,7 @@ struct configurableChat{
14221423
if (params.sparams.p_step != paramsDefault.sparams.p_step) newCard["p_step"] = params.sparams.p_step;
14231424
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) newCard["xtc_probability"] = params.sparams.xtc_probability;
14241425
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) newCard["xtc_threshold"] = params.sparams.xtc_threshold;
1426+
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) newCard["xtc_probability_once"] = params.sparams.xtc_probability_once;
14251427
//penalties
14261428
if (params.sparams.penalty_threshold != paramsDefault.sparams.penalty_threshold) newCard["penalty_threshold"] = params.sparams.penalty_threshold;
14271429
if (params.sparams.penalty_repeat != paramsDefault.sparams.penalty_repeat) newCard["repeat_penalty"] = params.sparams.penalty_repeat;

0 commit comments

Comments
 (0)