Skip to content

Commit c9e7714

Browse files
authored
improve sampling time
Differential Revision: D60742125 Pull Request resolved: #4644
1 parent 7f34796 commit c9e7714

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

extension/llm/sampler/sampler.cpp

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
*/
3434

3535
#include <executorch/extension/llm/sampler/sampler.h>
36+
#include <algorithm>
3637

3738
namespace torch {
3839
namespace executor {
@@ -66,18 +67,6 @@ int32_t Sampler::sample_mult(T* probabilities, float coin) {
6667
return vocab_size_ - 1; // in case of rounding errors
6768
}
6869

69-
template <typename T>
70-
static int32_t compare(const void* a, const void* b) {
71-
ProbIndex<T>* a_ = (ProbIndex<T>*)a;
72-
ProbIndex<T>* b_ = (ProbIndex<T>*)b;
73-
if (a_->prob > b_->prob) {
74-
return -1;
75-
} else if (a_->prob < b_->prob) {
76-
return 1;
77-
}
78-
return 0;
79-
}
80-
8170
template <typename T>
8271
int32_t Sampler::sample_topp(T* probabilities, float coin) {
8372
// top-p sampling (or "nucleus sampling") samples from the smallest set of
@@ -100,7 +89,11 @@ int32_t Sampler::sample_topp(T* probabilities, float coin) {
10089
n0++;
10190
}
10291
}
103-
qsort(probindex.get(), n0, sizeof(ProbIndex<T>), compare<T>);
92+
93+
auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
94+
return a.prob > b.prob;
95+
};
96+
std::sort(probindex.get(), probindex.get() + n0, compare);
10497

10598
// truncate the list where cumulative probability exceeds topp
10699
T cumulative_prob = 0;

0 commit comments

Comments
 (0)