Skip to content

Commit 15552a6

Browse files
committed
fix: clean-up
1 parent d51de0e commit 15552a6

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

code/ac/llama/Session.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,18 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
2020
return llama_batch_get_one(nonConstTokens, int32_t(tokens.size()));
2121
}
2222

23-
void fillLogits(TokenDataVector& out, llama_context* lctx) {
23+
TokenDataVector fillLogits(llama_context* lctx) {
2424
const auto* logits = llama_get_logits_ith(lctx, -1);
2525

2626
const auto* lmodel = llama_get_model(lctx);
2727
const int vocabSize = llama_vocab_n_tokens(llama_model_get_vocab(lmodel));
2828

29-
out.resize(vocabSize);
30-
29+
TokenDataVector result(vocabSize);
3130
for (llama_token id = 0; id < vocabSize; id++) {
32-
out[id] = {id, logits[id]};
31+
result[id] = {id, logits[id]};
3332
}
33+
34+
return result;
3435
}
3536
}
3637

@@ -180,18 +181,17 @@ Token Session::getToken() {
180181
return m_state.m_currToken;
181182
}
182183

183-
TokenDataVector Session::getSampledTokenData(int32_t topK, float /*topP*/) {
184+
TokenDataVector Session::getSampledTokenData(int32_t topK) {
184185
flushPendingState();
185186

186-
TokenDataVector tempData;
187-
fillLogits(tempData, m_ctx);
187+
TokenDataVector tempData = fillLogits(m_ctx);
188188

189189
std::sort(tempData.begin(), tempData.end(), [](const TokenData & a, const TokenData & b) {
190190
return a.logit > b.logit;
191191
});
192192

193193
TokenDataVector result;
194-
result.insert(result.end(), tempData.begin(), tempData.begin() + topK);;
194+
result.insert(result.end(), tempData.begin(), tempData.begin() + topK);
195195

196196
return result;
197197
}

code/ac/llama/Session.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Session {
3737
// main functions to interact with the model
3838
void pushPrompt(std::span<const Token> prompt, std::span<const Token> postfix = {});
3939
Token getToken();
40-
TokenDataVector getSampledTokenData(int32_t topK, float topP = 0.95f);
40+
TokenDataVector getSampledTokenData(int32_t topK);
4141
std::vector<uint8_t> getState();
4242
private:
4343
enum class Source {

0 commit comments

Comments
 (0)