@@ -19,6 +19,37 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
1919 auto nonConstTokens = const_cast <Token*>(tokens.data ());
2020 return llama_batch_get_one (nonConstTokens, int32_t (tokens.size ()));
2121}
22+
23+ void fillLogits (TokenDataVector& out, llama_context* lctx) {
24+ const auto * logits = llama_get_logits_ith (lctx, -1 );
25+
26+ const auto * lmodel = llama_get_model (lctx);
27+ const int vocabSize = llama_vocab_n_tokens (llama_model_get_vocab (lmodel));
28+
29+ out.resize (vocabSize);
30+
31+ for (llama_token id = 0 ; id < vocabSize; id++) {
32+ out[id] = {id, logits[id], 0 .0f };
33+ }
34+ }
35+
36+ static void applySoftMax (TokenDataVector& data) {
37+ // Apply softmax to the logits
38+ // The vector should be sorted in descending order
39+
40+ float max_l = data[0 ].logit ;
41+ float cum_sum = 0 .0f ;
42+
43+ for (size_t i = 0 ; i < data.size (); ++i) {
44+ float p = expf (data[i].logit - max_l);
45+ data[i].prob = p;
46+ cum_sum += p;
47+ }
48+
49+ for (size_t i = 0 ; i < data.size (); ++i) {
50+ data[i].prob /= cum_sum;
51+ }
52+ }
2253}
2354
2455Session::Session (Instance& instance, llama_context* ctx, InitParams params)
@@ -167,22 +198,22 @@ Token Session::getToken() {
167198 return m_state.m_currToken ;
168199}
169200
170- TokenDataVector Session::getSampledTokenData (int32_t topK, float topP) {
201+ TokenDataVector Session::getSampledTokenData (int32_t topK, float /* topP*/ ) {
171202 flushPendingState ();
172203
173- Sampler::Params sParams = {
174- . topK = topK,
175- . topP = topP,
176- . samplerSequence = {
177- Sampler::SamplingType::Top_K,
178- Sampler::SamplingType::Top_P,
179- }
180- } ;
181- Sampler sampler ( const_cast <Model&>(m_instance. model ()), sParams );
204+ TokenDataVector tempData;
205+ fillLogits (tempData, m_ctx);
206+
207+ std::sort (tempData. begin (), tempData. end (), []( const TokenData & a, const TokenData & b) {
208+ return a. logit > b. logit ;
209+ });
210+
211+ TokenDataVector result ;
212+ result. insert (result. end (), tempData. begin (), tempData. begin () + topK );
182213
183- auto logits = sampler. extractTokenData (m_ctx );
214+ applySoftMax (result );
184215
185- return logits ;
216+ return result ;
186217}
187218
188219std::vector<uint8_t > Session::getState () {
0 commit comments