
Commit 39940e5

Algorithm rework

1. Scan tokens from the top until the first non-penalizable one.
2. Remove the last captured token (the least probable one above the threshold).
3. Shift the following tokens up to overwrite the remaining penalizable ones.
4. Penalize the captured tokens and put them at the bottom.
1 parent 094caea commit 39940e5
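
To make the reworked pass easy to follow outside the diff context, here is a minimal standalone sketch of the four steps. It is a sketch, not the llama.cpp implementation: TokenData, xtc_rework(), and the demo values are hypothetical stand-ins for llama_token_data, llama_sample_xtc_apply(), and the ctx->threshold / ctx->threshold_max / ctx->min_keep parameters, and the first penalizable index is tracked with top_tkns.empty() rather than the pos sentinel used in the commit.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_token_data.
struct TokenData {
    int   id;
    float logit;
    float p;
};

// Sketch of the reworked XTC pass; 'cand' must already be sorted by
// probability in descending order (as after llama_sampler_softmax_impl).
void xtc_rework(std::vector<TokenData> & cand,
                float threshold, float threshold_max, size_t min_keep) {
    std::vector<TokenData> top_tkns;
    size_t pos = 0;

    // 1. Scan from the top until the first non-penalizable token.
    for (size_t i = 0; i < cand.size(); ++i) {
        if (cand[i].p < threshold) break;
        if (cand[i].p <= threshold_max) {
            if (top_tkns.empty()) pos = i; // position of the first penalizable token
            top_tkns.push_back(cand[i]);
        }
    }

    // Need at least two captured tokens: one survives, the rest are penalized.
    if (top_tkns.size() < 2) return;

    // 2. Drop the last captured token so the least probable one above
    //    the threshold stays in place.
    top_tkns.pop_back();

    const size_t to_remove = top_tkns.size();
    size_t size_new = cand.size() - to_remove;

    // 3. Shift the following tokens up to overwrite the penalizable ones.
    for (size_t i = pos; i < size_new; ++i) {
        cand[i] = cand[i + to_remove];
    }

    // 4. Penalize the captured tokens and park them at the bottom.
    for (size_t i = 0; i < top_tkns.size(); ++i) {
        top_tkns[i].logit = -999.0f;
        cand[cand.size() - (1 + i)] = top_tkns[i];
    }

    if (size_new < min_keep) size_new = min_keep;
    cand.resize(size_new);
}

int main() {
    // Toy distribution, already sorted and normalized.
    std::vector<TokenData> cand = {
        {0, 2.0f, 0.40f}, {1, 1.5f, 0.25f}, {2, 1.0f, 0.15f},
        {3, 0.8f, 0.12f}, {4, 0.2f, 0.05f}, {5, 0.1f, 0.03f},
    };

    xtc_rework(cand, /*threshold=*/0.10f, /*threshold_max=*/1.00f, /*min_keep=*/1);

    for (const TokenData & t : cand) {
        printf("id=%d p=%.2f logit=%.1f\n", t.id, t.p, t.logit);
    }
    return 0;
}
```

Compiled with g++ -std=c++11, the demo keeps the least probable captured token (p = 0.12) at the top, moves the three more probable ones to the back with a logit of -999.0f, and shrinks the candidate list from six entries to three.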

File tree: 1 file changed, +23 -20 lines changed


src/llama-sampling.cpp

Lines changed: 23 additions & 20 deletions
```diff
@@ -1096,37 +1096,40 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
     // in case it's not sorted/recalculated yet
     llama_sampler_softmax_impl(cur_p);
 
-    std::vector<llama_token_data> cur;
-
-    int removed = -1; // to keep one out
+    std::vector<llama_token_data> top_tkns;
     int pos = 0;
 
-    // going through all candidates from back to front, easier to keep the last of probables
-    for (int i = (cur_p->size - 1); i >= 0; --i) {
-        if (cur_p->data[i].p >= ctx->threshold && cur_p->data[i].p <= ctx->threshold_max) {
-            ++removed;
-            if (removed > 0) {
-                // .logits are used for sorting and calculating .p in llama_sample_softmax_impl
-                cur_p->data[i].logit = -999.0f;
-                cur.emplace_back(cur_p->data[i]);
-                pos = i;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            if (cur_p->data[i].p <= ctx->threshold_max) {
+                top_tkns.emplace_back(cur_p->data[i]);
+                // capture position of the first penalizable
+                if (pos == -1) pos = i;
             }
-        }
+        } else break;
     }
 
-    if (removed > 0) {
-        size_t size_new = cur_p->size - removed;
+    // check if there are enough penalizable tokens
+    if (top_tkns.size() >= 2) {
+        // keep the least probable from top ones
+        top_tkns.pop_back();
+
+        // define new size
+        size_t to_remove = top_tkns.size();
+        size_t size_new = cur_p->size - to_remove;
 
-        // shift tokens to remove the penalized ones
+        // shift tokens starting from pos
         for (size_t i = pos; i < size_new - pos; ++i) {
-            cur_p->data[i] = cur_p->data[i + removed];
+            cur_p->data[i] = cur_p->data[i + to_remove];
         }
 
-        // put the prenalized ones at the back
-        for (size_t i = 0; i < cur.size(); ++i) {
-            cur_p->data[cur_p->size - (1 + i)] = cur[i];
+        // penalize top tokens and put them at the back
+        for (size_t i = 0; i < top_tkns.size(); ++i) {
+            top_tkns[i].logit = -999.0f;
+            cur_p->data[cur_p->size - (1 + i)] = top_tkns[i];
         }
 
+        // resize
         if (size_new < ctx->min_keep) size_new = ctx->min_keep;
         cur_p->size = size_new;
     }
```
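
Tracing the committed code on the same toy distribution as the sketch above (0.40, 0.25, 0.15, 0.12, 0.05, 0.03, sorted by llama_sampler_softmax_impl) with threshold = 0.10 and threshold_max = 1.00: the forward scan captures the four tokens at or above the threshold and breaks at 0.05, pop_back() spares 0.12 as the surviving top choice, the shift overwrites the slots of the three captured tokens with the tokens below them, the penalize loop re-appends the captured tokens at the back with a logit of -999.0f, and cur_p->size drops from 6 to 3 (in this trace pos remains at its initial value of 0, which is also the index of the first penalizable token).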
