File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed
Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ Sampler::Sampler(
131131 float topp,
132132 unsigned long long rng_seed)
133133 : vocab_size_(vocab_size),
134- temperature_ ( temperature),
134+ inv_temperature_ ( static_cast < bool >( temperature) ? 1.0f / temperature : 0 ),
135135 topp_(topp),
136136 rng_state_(rng_seed) {}
137137
@@ -172,13 +172,13 @@ template <typename T>
172172int32_t Sampler::sample (T* logits) {
173173 // sample the token given the logits and some hyperparameters
174174 int next;
175- if (temperature_ == 0 .0f ) {
175+ if (inv_temperature_ == 0 .0f ) {
176176 // greedy argmax sampling: take the token with the highest probability
177177 next = sample_argmax (logits);
178178 } else {
179179 // apply the temperature to the logits
180180 for (int q = 0 ; q < vocab_size_; q++) {
181- logits[q] /= temperature_ ;
181+ logits[q] *= inv_temperature_ ;
182182 }
183183 // apply softmax to the logits to get the probabilities for next token
184184 softmax (logits, vocab_size_);
Original file line number Diff line number Diff line change @@ -51,7 +51,8 @@ class Sampler {
5151
5252 private:
5353 int32_t vocab_size_;
54- float temperature_;
54+ // reciprocal of temperature, or 0 if temperature == 0.
55+ float inv_temperature_;
5556 float topp_;
5657 unsigned long long rng_state_;
5758};
You can’t perform that action at this time.
0 commit comments