Skip to content

Commit 44eb8c9

Browse files
committed
XTC: all candidates are scanned now, better sorting
* going through all candidates to detect all tokens above the threshold * sorting prioritizes the last token when .logit values are equal (ensures that, if only two tokens are available and both are penalizable, the more probable one is penalized)
1 parent 302e9c9 commit 44eb8c9

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

base/llama-addon.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array *
5050
const int64_t t_start_sample_us = ggml_time_us();
5151
int id_first = -1;
5252
size_t removed = 0;
53-
for (size_t i = 0; i < (candidates->size - 1); ++i) {
53+
// going through all candidates to correctly trigger the effect
54+
for (size_t i = 0; i < candidates->size; ++i) {
5455
if (candidates->data[i].p >= xtc_threshold) {
5556
if (id_first == -1) {
5657
id_first = i;
@@ -66,13 +67,17 @@ void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array *
6667

6768
if (removed >= xtc_min) {
6869
// penalizing by first id
69-
if (xtc_probability_once || chance <= xtc_probability) candidates->data[id_first].logit = -999.0f;
70-
// sorting with new logits
70+
if (xtc_probability_once || chance <= xtc_probability) {
71+
candidates->data[id_first].logit = -999.0f;
72+
}
73+
// sorting with new logits, but prioritizing last token since we'll resize later
7174
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
72-
return a.logit > b.logit;
75+
return a.logit >= b.logit;
7376
});
74-
//resizing now that penalized tokens are at the back
75-
candidates->size = candidates->size - removed;
77+
78+
// resizing now that penalized tokens are at the back, but leave at least 1 token
79+
// this ensures that if only 2 tokens are present, at least one (more probable) is penalized
80+
candidates->size = (candidates->size > removed ? candidates->size - removed : 1);
7681
}
7782
llama_set_time(ctx, t_start_sample_us);
7883
}

0 commit comments

Comments
 (0)