Skip to content

Commit 6b69d0b

Browse files
committed
XTC: added xtc_threshold_max parameter as an upper limit
* 1.0 by default, so doesn't affect anything * can be used to eliminate tokens within a range if you are sure that some top tokens are not clichéd (in finetuned models, for example)
1 parent 44eb8c9 commit 6b69d0b

File tree

7 files changed

+11
-5
lines changed

7 files changed

+11
-5
lines changed

base/llama-addon.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
#include <type_traits>
3737
#include <unordered_map>
3838

39-
void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, bool xtc_probability_once, int xtc_min, size_t min_keep) {
40-
if (xtc_probability <= 0.0f || xtc_threshold <= 0.0f || candidates->size <= 1) {
39+
void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, float xtc_threshold_max, bool xtc_probability_once, int xtc_min, size_t min_keep) {
40+
if (xtc_probability <= 0.0f || xtc_threshold <= 0.0f || xtc_threshold_max == xtc_threshold || xtc_min < 1 || candidates->size <= 1) {
4141
return;
4242
}
4343

@@ -52,7 +52,7 @@ void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array *
5252
size_t removed = 0;
5353
// going through all candidates to correctly trigget the effect
5454
for (size_t i = 0; i < candidates->size; ++i) {
55-
if (candidates->data[i].p >= xtc_threshold) {
55+
if (candidates->data[i].p >= xtc_threshold && candidates->data[i].p <= xtc_threshold_max) {
5656
if (id_first == -1) {
5757
id_first = i;
5858
++removed;

base/llama-addon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
llama_token_data_array * candidates,
88
float xtc_probability,
99
float xtc_threshold,
10+
float xtc_threshold_max,
1011
bool xtc_probability_once,
1112
int xtc_min,
1213
size_t min_keep);

base/sampling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ void sampler_queue(
146146
const float p_step = params.p_step;
147147
const float xtc_probability = params.xtc_probability;
148148
const float xtc_threshold = params.xtc_threshold;
149+
const float xtc_threshold_max = params.xtc_threshold_max;
149150
const float xtc_probability_once = params.xtc_probability_once;
150151
const float xtc_min = params.xtc_min;
151152
const std::string samplers_sequence = params.samplers_sequence;
@@ -158,7 +159,7 @@ void sampler_queue(
158159
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
159160
case 'm': llama_sample_min_p_addon (ctx_main, &cur_p, min_p, min_keep); break;
160161
case 's': llama_sample_p_step_addon (ctx_main, &cur_p, p_step, min_keep); break;
161-
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, xtc_probability_once, xtc_min, min_keep); break;
162+
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, xtc_threshold_max, xtc_probability_once, xtc_min, min_keep); break;
162163
case 't': {
163164
if (dynatemp_range>0)
164165
{

base/sampling.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ typedef struct llama_sampling_params {
5959
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
6060
float xtc_probability = 0.5; // probability of removing a top token
6161
float xtc_threshold = 0.1; // minimum tokens probablitity for this to run
62+
float xtc_threshold_max = 1.0; // maximum tokens probablitity for this to run
6263
bool xtc_probability_once = false; // should we calculate chances one or for each token
6364
int xtc_min = 2; // minimum number of penalizeable tokens
6465
std::string samplers_sequence = "kfypmts"; // top_k, tail_free, typical_p, top_p, min_p, temp, p_step

chat_plain.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ class chat
557557
case 'f': result += name_tfs_z; if (params.sparams.tfs_z != paramsDefault.sparams.tfs_z) result += std::format("={:.3f}",params.sparams.tfs_z); break;
558558
case 'y': result += name_typical_p; if (params.sparams.typical_p != paramsDefault.sparams.typical_p) result += std::format("={:.3f}",params.sparams.typical_p); break;
559559
case 's': result += name_p_step; if (params.sparams.p_step != paramsDefault.sparams.p_step) result += std::format("={:.3f}",params.sparams.p_step); break;
560-
case 'x': result += std::format("xtc={:.3f}-{:.03f}%",params.sparams.xtc_threshold,params.sparams.xtc_probability); break;
560+
case 'x': result += std::format("xtc={:.3f}-{:.3f}({}%/{})",params.sparams.xtc_threshold,params.sparams.xtc_threshold_max,params.sparams.xtc_probability*100,params.sparams.xtc_min); if (params.sparams.xtc_probability_once) result += "once"; else result += "each"; break;
561561
case 'p': result += name_top_p; if (params.sparams.top_p != paramsDefault.sparams.top_p) result += std::format("={:.3f}",params.sparams.top_p); break;
562562
case 'm': result += name_min_p; if (params.sparams.min_p != paramsDefault.sparams.min_p) result += std::format("={:.3f}",params.sparams.min_p); break;
563563
case 't': {

include/jsonParams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
519519
if (checkJNum(config, "tfs_z")) params.sparams.tfs_z = config["tfs_z"];
520520
if (checkJNum(config, "xtc_probability")) params.sparams.xtc_probability = config["xtc_probability"];
521521
if (checkJNum(config, "xtc_threshold")) params.sparams.xtc_threshold = config["xtc_threshold"];
522+
if (checkJNum(config, "xtc_threshold_max")) params.sparams.xtc_threshold_max = config["xtc_threshold_max"];
522523
if (checkJNum(config, "xtc_min")) params.sparams.xtc_min = config["xtc_min"];
523524
if (checkJBool(config, "xtc_probability_once")) params.sparams.xtc_probability_once = config["xtc_probability_once"];
524525

thread_chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,7 @@ struct configurableChat{
13131313
} else if (params.sparams.p_step != paramsDefault.sparams.p_step) modelConfig[model]["p_step"] = params.sparams.p_step;
13141314
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) modelConfig[model]["xtc_probability"] = params.sparams.xtc_probability;
13151315
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) modelConfig[model]["xtc_threshold"] = params.sparams.xtc_threshold;
1316+
if (params.sparams.xtc_threshold_max != paramsDefault.sparams.xtc_threshold_max) modelConfig[model]["xtc_threshold_max"] = params.sparams.xtc_threshold_max;
13161317
if (params.sparams.xtc_min != paramsDefault.sparams.xtc_min) modelConfig[model]["xtc_min"] = params.sparams.xtc_min;
13171318
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) modelConfig[model]["xtc_probability_once"] = params.sparams.xtc_probability_once;
13181319
// penalties
@@ -1424,6 +1425,7 @@ struct configurableChat{
14241425
if (params.sparams.p_step != paramsDefault.sparams.p_step) newCard["p_step"] = params.sparams.p_step;
14251426
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) newCard["xtc_probability"] = params.sparams.xtc_probability;
14261427
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) newCard["xtc_threshold"] = params.sparams.xtc_threshold;
1428+
if (params.sparams.xtc_threshold_max != paramsDefault.sparams.xtc_threshold_max) newCard["xtc_threshold_max"] = params.sparams.xtc_threshold_max;
14271429
if (params.sparams.xtc_min != paramsDefault.sparams.xtc_min) newCard["xtc_min"] = params.sparams.xtc_min;
14281430
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) newCard["xtc_probability_once"] = params.sparams.xtc_probability_once;
14291431
//penalties

0 commit comments

Comments
 (0)