
Commit cb33cd2

refactor: remove unused extractTokenData method and improve getSampledTokenData logic
1 parent: a270b2d

File tree: 3 files changed (+43 −31 lines)

code/ac/llama/Sampler.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,23 +172,6 @@ Token Sampler::sample(llama_context* lctx, int idx, bool grammarFirst) {
172172
return cur.data[cur.selected].id;
173173
}
174174

175-
TokenDataVector Sampler::extractTokenData(llama_context* lctx) {
176-
auto chain = m_samplerChain.get();
177-
178-
auto cur = fillLogits(m_cur, lctx, -1);
179-
180-
llama_sampler_apply(chain, &cur);
181-
182-
TokenDataVector result(cur.size);
183-
184-
for (size_t i = 0; i < cur.size; i++)
185-
{
186-
result[i] = {cur.data[i].id, cur.data[i].logit, cur.data[i].p};
187-
}
188-
189-
return result;
190-
}
191-
192175
void Sampler::reset() {
193176
llama_sampler_reset(m_grammarSampler.get());
194177
llama_sampler_reset(m_samplerChain.get());

code/ac/llama/Sampler.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ class AC_LLAMA_EXPORT Sampler {
101101
// idx is optional for sampling from the logits of the ith token
102102
Token sample(llama_context* lctx, int idx = -1, bool grammarFirst = false);
103103

104-
TokenDataVector extractTokenData(llama_context* lctx);
105-
106104
// accept token as sampled
107105
// if acceptGrammar is true, the token is accepted both by the sampling chain and the grammar
108106
void accept(Token id, bool acceptGrammar);

code/ac/llama/Session.cpp

Lines changed: 43 additions & 12 deletions

@@ -19,6 +19,37 @@ llama_batch makeInputBatch(std::span<const Token> tokens) {
     auto nonConstTokens = const_cast<Token*>(tokens.data());
     return llama_batch_get_one(nonConstTokens, int32_t(tokens.size()));
 }
+
+void fillLogits(TokenDataVector& out, llama_context* lctx) {
+    const auto* logits = llama_get_logits_ith(lctx, -1);
+
+    const auto* lmodel = llama_get_model(lctx);
+    const int vocabSize = llama_vocab_n_tokens(llama_model_get_vocab(lmodel));
+
+    out.resize(vocabSize);
+
+    for (llama_token id = 0; id < vocabSize; id++) {
+        out[id] = {id, logits[id], 0.0f};
+    }
+}
+
+static void applySoftMax(TokenDataVector& data) {
+    // Apply softmax to the logits
+    // The vector should be sorted in descending order
+
+    float max_l = data[0].logit;
+    float cum_sum = 0.0f;
+
+    for (size_t i = 0; i < data.size(); ++i) {
+        float p = expf(data[i].logit - max_l);
+        data[i].prob = p;
+        cum_sum += p;
+    }
+
+    for (size_t i = 0; i < data.size(); ++i) {
+        data[i].prob /= cum_sum;
+    }
+}
 }
 
 Session::Session(Instance& instance, llama_context* ctx, InitParams params)

@@ -167,22 +198,22 @@ Token Session::getToken() {
     return m_state.m_currToken;
 }
 
-TokenDataVector Session::getSampledTokenData(int32_t topK, float topP) {
+TokenDataVector Session::getSampledTokenData(int32_t topK, float /*topP*/) {
     flushPendingState();
 
-    Sampler::Params sParams = {
-        .topK = topK,
-        .topP = topP,
-        .samplerSequence = {
-            Sampler::SamplingType::Top_K,
-            Sampler::SamplingType::Top_P,
-        }
-    };
-    Sampler sampler(const_cast<Model&>(m_instance.model()), sParams);
+    TokenDataVector tempData;
+    fillLogits(tempData, m_ctx);
+
+    std::sort(tempData.begin(), tempData.end(), [](const TokenData & a, const TokenData & b) {
+        return a.logit > b.logit;
+    });
+
+    TokenDataVector result;
+    result.insert(result.end(), tempData.begin(), tempData.begin() + topK);
 
-    auto logits = sampler.extractTokenData(m_ctx);
+    applySoftMax(result);
 
-    return logits;
+    return result;
 }
 
 std::vector<uint8_t> Session::getState() {
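
For reference, here is a minimal, self-contained sketch (plain C++, no llama.cpp dependency) of the top-K-plus-softmax scheme the reworked getSampledTokenData follows: sort by logit in descending order, keep the K highest entries, then normalize them with a max-shifted softmax. The TokenData layout and the helper name below are illustrative assumptions for this sketch, not the actual types from the repository.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative stand-in for the repository's TokenData (assumption).
struct TokenData {
    int   token;
    float logit;
    float prob;
};

// Keep the K highest-logit entries and convert their logits to probabilities.
// Subtracting the max logit before exp() keeps the exponentials from overflowing,
// which is why applySoftMax in the diff reads data[0].logit after sorting.
std::vector<TokenData> sampledTokenData(std::vector<TokenData> all, size_t topK) {
    std::sort(all.begin(), all.end(),
              [](const TokenData& a, const TokenData& b) { return a.logit > b.logit; });

    if (topK > all.size()) topK = all.size();
    std::vector<TokenData> result(all.begin(), all.begin() + topK);

    const float maxLogit = result[0].logit;
    float cumSum = 0.0f;
    for (auto& td : result) {
        td.prob = std::exp(td.logit - maxLogit);
        cumSum += td.prob;
    }
    for (auto& td : result) {
        td.prob /= cumSum;
    }
    return result;
}

int main() {
    std::vector<TokenData> logits = {{0, 1.2f, 0}, {1, 3.4f, 0}, {2, -0.5f, 0}, {3, 2.1f, 0}};
    for (const auto& td : sampledTokenData(logits, 3)) {
        std::printf("token %d logit %.2f prob %.3f\n", td.token, td.logit, td.prob);
    }
}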
