Skip to content

Commit 4a5b587

Browse files
committed
llama : handle temp <= 0.0 in the temp_ext sampler too
ggml-ci
1 parent cd97850 commit 4a5b587

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

src/llama-sampling.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) {
6363
}
6464
*/
6565

66+
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
67+
if (temp <= 0.0f) {
68+
// find the token with the highest logit and set the rest to -inf
69+
llama_token max_id = cur_p->data[0].id;
70+
float max_logit = cur_p->data[0].logit;
71+
72+
for (size_t i = 1; i < cur_p->size; ++i) {
73+
if (cur_p->data[i].logit > max_logit) {
74+
max_id = cur_p->data[i].id;
75+
max_logit = cur_p->data[i].logit;
76+
}
77+
}
78+
79+
for (size_t i = 0; i < cur_p->size; ++i) {
80+
if (cur_p->data[i].id != max_id) {
81+
cur_p->data[i].logit = -INFINITY;
82+
}
83+
}
84+
85+
return;
86+
}
87+
88+
for (size_t i = 0; i < cur_p->size; ++i) {
89+
cur_p->data[i].logit /= temp;
90+
}
91+
}
92+
6693
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
6794
GGML_ASSERT(cur_p->size > 0);
6895

@@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
916943
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
917944
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
918945

919-
if (ctx->temp <= 0.0f) {
920-
// find the token with the highest logit and set the rest to -inf
921-
llama_token max_id = cur_p->data[0].id;
922-
float max_logit = cur_p->data[0].logit;
923-
924-
for (size_t i = 1; i < cur_p->size; ++i) {
925-
if (cur_p->data[i].logit > max_logit) {
926-
max_id = cur_p->data[i].id;
927-
max_logit = cur_p->data[i].logit;
928-
}
929-
}
930-
931-
for (size_t i = 0; i < cur_p->size; ++i) {
932-
if (cur_p->data[i].id != max_id) {
933-
cur_p->data[i].logit = -INFINITY;
934-
}
935-
}
936-
937-
return;
938-
}
939-
940-
for (size_t i = 0; i < cur_p->size; ++i) {
941-
cur_p->data[i].logit /= ctx->temp;
942-
}
946+
llama_sampler_temp_impl(cur_p, ctx->temp);
943947
}
944948

945949
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
10241028
#endif
10251029

10261030
// Apply the dynamically calculated temperature scaling
1027-
for (size_t i = 0; i < cur_p->size; ++i) {
1028-
cur_p->data[i].logit /= dyn_temp;
1029-
}
1031+
llama_sampler_temp_impl(cur_p, dyn_temp);
10301032

10311033
// Re-compute softmax probabilities after scaling logits with dynamic temperature
10321034
const double max_l_double = cur_p->data[0].logit;
@@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
10501052
}
10511053
#endif
10521054
} else {
1053-
for (size_t i = 0; i < cur_p->size; ++i) {
1054-
cur_p->data[i].logit /= ctx->temp;
1055-
}
1055+
llama_sampler_temp_impl(cur_p, ctx->temp);
10561056
}
10571057
}
10581058

tests/test-sampling.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
7070
tester.check();
7171
}
7272

73+
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
74+
sampler_tester tester(probs, probs_expected);
75+
76+
DUMP(&tester.cur_p);
77+
tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
78+
tester.apply(llama_sampler_init_dist (0));
79+
DUMP(&tester.cur_p);
80+
81+
tester.check();
82+
}
83+
7384
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
7485
sampler_tester tester(probs, probs_expected);
7586

@@ -277,6 +288,9 @@ int main(void) {
277288
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
278289
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
279290

291+
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
292+
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
293+
280294
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
281295
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
282296
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);

0 commit comments

Comments
 (0)