Skip to content

Commit 0097de5

Browse files
authored
improve performance by actually applying nsigma's masking (LostRuins#1602)
merging, please report any issues.
1 parent 57ce374 commit 0097de5

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

gpttype_adapter.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,12 +1433,11 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) {
14331433
}
14341434

14351435
void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
1436-
14371436
if (nsigma <= 0.0f || cur_p->size <= 1) {
14381437
return;
14391438
}
14401439
// find max logit and calculate mean
1441-
float nsigmax = cur_p->data[0].logit;
1440+
float nsigmax = cur_p->data[0].logit;
14421441
float logits_sum = 0;
14431442
for (size_t i = 0; i < cur_p->size; ++i) {
14441443
if (cur_p->data[i].logit > nsigmax) {
@@ -1456,11 +1455,10 @@ void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
14561455
float nsigstd = sqrt(nsigacc / cur_p->size);
14571456

14581457
//apply mask
1459-
for (size_t i = 0; i < cur_p->size; ++i) {
1460-
if (cur_p->data[i].logit < nsigmax - (nsigma * nsigstd)) {
1461-
cur_p->data[i].logit -= 999.0f;
1462-
}
1463-
}
1458+
auto last = std::remove_if(cur_p->data, cur_p->data + cur_p->size,
1459+
[&](auto & tk) { return tk.logit < nsigmax - (nsigma * nsigstd); });
1460+
cur_p->size = last - cur_p->data;
1461+
14641462
sample_softmax(cur_p);
14651463
}
14661464

0 commit comments

Comments
 (0)