@@ -20,17 +20,18 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
2020 return llama_batch_get_one (nonConstTokens, int32_t (tokens.size ()));
2121}
2222
23- void fillLogits (TokenDataVector& out, llama_context* lctx) {
23+ TokenDataVector fillLogits (llama_context* lctx) {
2424 const auto * logits = llama_get_logits_ith (lctx, -1 );
2525
2626 const auto * lmodel = llama_get_model (lctx);
2727 const int vocabSize = llama_vocab_n_tokens (llama_model_get_vocab (lmodel));
2828
29- out.resize (vocabSize);
30-
29+ TokenDataVector result (vocabSize);
3130 for (llama_token id = 0 ; id < vocabSize; id++) {
32- out [id] = {id, logits[id]};
31+ result [id] = {id, logits[id]};
3332 }
33+
34+ return result;
3435}
3536}
3637
@@ -180,18 +181,17 @@ Token Session::getToken() {
180181 return m_state.m_currToken ;
181182}
182183
183- TokenDataVector Session::getSampledTokenData (int32_t topK, float /* topP */ ) {
184+ TokenDataVector Session::getSampledTokenData (int32_t topK) {
184185 flushPendingState ();
185186
186- TokenDataVector tempData;
187- fillLogits (tempData, m_ctx);
187+ TokenDataVector tempData = fillLogits (m_ctx);
188188
189189 std::sort (tempData.begin (), tempData.end (), [](const TokenData & a, const TokenData & b) {
190190 return a.logit > b.logit ;
191191 });
192192
193193 TokenDataVector result;
194- result.insert (result.end (), tempData.begin (), tempData.begin () + topK);;
194+ result.insert (result.end (), tempData.begin (), tempData.begin () + topK);
195195
196196 return result;
197197}
0 commit comments