Skip to content

Commit 41e1665

Browse files
authored
First fixes by comments
Still need to look into sorting
1 parent db54ac5 commit 41e1665

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

common/arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
975975
).set_sparam());
976976
add_opt(llama_arg(
977977
{"--xtc-t"}, "N",
978-
format("xtc threshold (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_t),
978+
format("xtc threshold (default: %.1f, 0.0 or 1.0 = disabled)", (double)params.sparams.xtc_t),
979979
[](gpt_params & params, const std::string & value) {
980980
params.sparams.xtc_t = std::stof(value);
981981
}

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ struct gpt_sampler_params {
110110
float top_p = 0.95f; // 1.0 = disabled
111111
float min_p = 0.05f; // 0.0 = disabled
112112
float xtc_p = 0.50f; // 0.0 = disabled
113-
float xtc_t = 0.10f; // 1.0 = disabled
113+
float xtc_t = 0.10f; // 0.0 or 1.0 = disabled
114114
float xtc_t_max = 1.00f; // 0.0 = disabled
115115
float tfs_z = 1.00f; // 1.0 = disabled
116116
float typ_p = 1.00f; // typical_p, 1.0 = disabled

src/llama-sampling.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,37 +1075,43 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/
10751075
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
10761076
const auto * ctx = (llama_sampler_xtc *) smpl->ctx;
10771077

1078-
if (ctx->probability <= 0.0f || ctx->threshold <= 0.0f || cur_p->size <= 1 || ctx->min_keep <= 2) {
1078+
if (ctx->probability <= 0.0f
1079+
|| ctx->threshold <= 0.0f
1080+
|| ctx->threshold >= 1.0f
1081+
|| ctx->threshold_max <= 0.0f
1082+
|| ctx->threshold_max <= ctx->threshold
1083+
|| cur_p->size <= 2
1084+
|| ctx->min_keep <= 2) {
10791085
return;
10801086
}
10811087

10821088
std::random_device rd;
1083-
float chance = (float)(rd()%100)/100;
1089+
float chance = (float)(rd()%100 - 1)/100;
10841090
if (chance > ctx->probability) return;
1091+
10851092
// in case it's not sorted/recalculated yet
10861093
llama_sampler_softmax_impl(cur_p);
10871094

1088-
int removed = 0;
1095+
int found = 0;
10891096
// going through all candidates from back to front, easier to keep the last of probables
10901097
for (int i = (cur_p->size - 1); i >= 0; --i) {
10911098
if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
1092-
++removed;
1093-
if (removed > 1) {
1099+
++found;
1100+
if (found > 1) {
10941101
// .logits are used for sorting and calculating .p in llama_sample_softmax_impl
10951102
cur_p->data[i].logit = -999.0f;
10961103
}
10971104
}
10981105
}
10991106

1100-
if (removed > 1) {
1107+
if (found > 1) {
11011108
// sorting with new logits, ex-last probable will be the first anyway
11021109
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
11031110
return a.logit > b.logit;
11041111
});
1105-
cur_p->sorted = true;
11061112

11071113
// resizing now that penalized tokens are at the back
1108-
cur_p->size = cur_p->size - removed + 1;
1114+
cur_p->size = cur_p->size - found + 1;
11091115
}
11101116
}
11111117

0 commit comments

Comments (0)