File tree Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -82,8 +82,6 @@ struct ET_EXPERIMENTAL Stats {
8282 long aggregate_sampling_timer_start_timestamp = 0 ;
8383};
8484
85- static constexpr auto kTopp = 0 .9f ;
86-
8785inline std::string stats_to_json_string (const Stats& stats) {
8886 std::stringstream ss;
8987 ss << " {\" prompt_tokens\" :" << stats.num_prompt_tokens << " ,"
@@ -168,7 +166,6 @@ namespace executorch {
168166namespace llm {
169167// TODO(T197294990): Remove these deprecated aliases once all users have moved
170168// to the new `::executorch` namespaces.
171- using ::executorch::extension::llm::kTopp ;
172169using ::executorch::extension::llm::print_report;
173170using ::executorch::extension::llm::Stats;
174171} // namespace llm
Original file line number Diff line number Diff line change 3434
3535#include < executorch/extension/llm/sampler/sampler.h>
3636#include < algorithm>
37+ #include < ctime>
3738
3839namespace executorch {
3940namespace extension {
@@ -129,6 +130,12 @@ Sampler::Sampler(
129130 topp_(topp),
130131 rng_state_(rng_seed) {}
131132
133+ Sampler::Sampler (int vocab_size, float temperature)
134+ : vocab_size_(vocab_size),
135+ inv_temperature_(static_cast <bool >(temperature) ? 1.0f / temperature : 0),
136+ topp_(kTopp ),
137+ rng_state_(std::time(nullptr )) {}
138+
132139template <typename T>
133140static void softmax (T* x, int size) {
134141 // find max value (for numerical stability)
Original file line number Diff line number Diff line change @@ -26,6 +26,8 @@ namespace extension {
2626namespace llm {
2727// A simple llama2 sampler.
2828
29+ inline constexpr auto kTopp = 0 .9f ;
30+
2931template <typename T>
3032struct ET_EXPERIMENTAL ProbIndex {
3133 T prob;
@@ -40,6 +42,8 @@ class ET_EXPERIMENTAL Sampler {
4042 float topp,
4143 unsigned long long rng_seed);
4244
45+ Sampler (int32_t vocab_size, float temperature);
46+
4347 template <typename T>
4448 int32_t sample (T* logits);
4549
@@ -71,3 +75,9 @@ using ::executorch::extension::llm::ProbIndex;
7175using ::executorch::extension::llm::Sampler;
7276} // namespace executor
7377} // namespace torch
78+
79+ namespace executorch ::llm {
80+ // TODO(T197294990): Remove these deprecated aliases once all users have moved
81+ // to the new `::executorch::extension::llm` namespaces.
82+ using ::executorch::extension::llm::kTopp ;
83+ } // namespace executorch::llm
You can’t perform that action at this time.
0 commit comments