Skip to content

Commit 0615989

Browse files
committed
cont : avoid extra loop in temperature sampler for sub-zero temp
ggml-ci
1 parent 4a5b587 commit 0615989

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

src/llama-sampling.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,15 @@ static void llama_log_softmax(float * array, size_t size) {
6666
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
6767
if (temp <= 0.0f) {
6868
// find the token with the highest logit and set the rest to -inf
69-
llama_token max_id = cur_p->data[0].id;
70-
float max_logit = cur_p->data[0].logit;
69+
size_t max_i = 0;
70+
float max_l = cur_p->data[0].logit;
7171

7272
for (size_t i = 1; i < cur_p->size; ++i) {
73-
if (cur_p->data[i].logit > max_logit) {
74-
max_id = cur_p->data[i].id;
75-
max_logit = cur_p->data[i].logit;
76-
}
77-
}
78-
79-
for (size_t i = 0; i < cur_p->size; ++i) {
80-
if (cur_p->data[i].id != max_id) {
73+
if (cur_p->data[i ].logit > max_l) {
74+
cur_p->data[max_i].logit = -INFINITY;
75+
max_i = i;
76+
max_l = cur_p->data[i].logit;
77+
} else {
8178
cur_p->data[i].logit = -INFINITY;
8279
}
8380
}

0 commit comments

Comments
 (0)