Skip to content

Commit 52d12e1

Browse files
committed
XTC: do not penalize the last token, sort conditionally
1 parent b3dce58 commit 52d12e1

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

base/llama-addon.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,18 @@ void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array *
4545

4646
const int64_t t_start_sample_us = ggml_time_us();
4747

48-
for (size_t i = 0; i < candidates->size ; ++i) {
48+
for (size_t i = 0; i < (candidates->size - 1); ++i) { // let's not penalize the last candidate even if it can be, may help with spaces
4949
if (candidates->data[i].p >= xtc_threshold) {
5050
std::srand(std::time(nullptr));
51-
if(std::rand() <= xtc_probability) candidates->data[i].p *= 0;
51+
if (std::rand() <= xtc_probability) {
52+
candidates->data[i].p *= 0;
53+
candidates->sorted = false;
54+
}
5255
}
5356
}
5457

55-
candidates->sorted = false;
56-
57-
// Re-normalize probabilities
58-
llama_sample_softmax(ctx, candidates);
58+
// Re-normalize probabilities if required
59+
if (candidates->sorted == false) llama_sample_softmax(ctx, candidates);
5960

6061
llama_set_time(ctx, t_start_sample_us);
6162
}

0 commit comments

Comments
 (0)