Skip to content

Commit d51de0e

Browse files
committed
refactor: remove probs from the token vector
1 parent 2265994 commit d51de0e

File tree

6 files changed

+38
-60
lines changed

6 files changed

+38
-60
lines changed

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -347,19 +347,16 @@ struct LocalLlama {
347347

348348
std::vector<int32_t> tokens(tokenData.size());
349349
std::vector<float> logits(tokenData.size());
350-
std::vector<float> probs(tokenData.size());
351350
for (size_t i = 0; i < tokenData.size(); i++) {
352351
tokens[i] = tokenData[i].token;
353352
logits[i] = tokenData[i].logit;
354-
probs[i] = tokenData[i].prob;
355353
}
356354

357355
instance.stopSession();
358356

359357
co_await io.push(Frame_from(sc::StateGeneralInstance::OpGetTokenData{}, {
360358
.tokens = std::move(tokens),
361359
.logits = std::move(logits),
362-
.probs = std::move(probs)
363360
}));
364361
}
365362

@@ -368,33 +365,23 @@ struct LocalLlama {
368365
IoEndpoint& io,
369366
const sc::StateGeneralInstance::OpCompareTokenData::Params& iparams) {
370367

371-
auto& l1 = iparams.logits1.value();
372-
auto& l2 = iparams.logits2.value();
373-
auto& p1 = iparams.probs1.value();
374-
auto& p2 = iparams.probs2.value();
375368
auto& t1 = iparams.tokens1.value();
376369
auto& t2 = iparams.tokens2.value();
377-
assert(l1.size() == t1.size() && l1.size() == p1.size());
378-
assert(l2.size() == t2.size() && l2.size() == p2.size());
370+
auto& l1 = iparams.logits1.value();
371+
auto& l2 = iparams.logits2.value();
372+
assert(l2.size() == t2.size());
373+
assert(l1.size() == t1.size());
379374

380375
ac::llama::TokenDataVector data1;
381-
data1.resize(t1.size());
376+
data1.reserve(t1.size());
382377
for (size_t i = 0; i < t1.size(); i++) {
383-
data1[i] = ac::llama::TokenData{
384-
.token = t1[i],
385-
.logit = l1[i],
386-
.prob = p1[i]
387-
};
378+
data1.emplace_back(ac::llama::TokenData{ t1[i], l1[i] });
388379
}
389380

390381
ac::llama::TokenDataVector data2;
391-
data2.resize(t2.size());
382+
data2.reserve(t2.size());
392383
for (size_t i = 0; i < t2.size(); i++) {
393-
data2[i] = ac::llama::TokenData{
394-
.token = t2[i],
395-
.logit = l2[i],
396-
.prob = p2[i]
397-
};
384+
data2.emplace_back(ac::llama::TokenData{ t2[i], l2[i] });
398385
}
399386

400387
co_await io.push(Frame_from(sc::StateGeneralInstance::OpCompareTokenData{}, {

ac-local-plugin/example/ep-run.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,13 @@ int main() try {
7878

7979
auto result2 = llama.call<schema::StateGeneralInstance::OpGetTokenData>({});
8080

81-
std::cout << "Token Data [0]: " << result2.tokens.value()[0] << ", " << result2.logits.value()[0] << ", " << result2.probs.value()[0] << std::endl;
81+
std::cout << "Token Data [0]: " << result2.tokens.value()[0] << ", " << result2.logits.value()[0] << std::endl;
8282

8383
auto result3 = llama.call<schema::StateGeneralInstance::OpCompareTokenData>({
8484
.tokens1 = result2.tokens,
8585
.logits1 = result2.logits,
86-
.probs1 = result2.probs,
8786
.tokens2 = result2.tokens,
88-
.logits2 = result2.logits,
89-
.probs2 = result2.probs
87+
.logits2 = result2.logits
9088
});
9189

9290
std::cout << "Token Data Compare: " << result3.equal.value() << std::endl;

ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,11 @@ struct StateGeneralInstance {
168168
struct Return {
169169
Field<std::vector<int32_t>> tokens;
170170
Field<std::vector<float>> logits;
171-
Field<std::vector<float>> probs;
172171

173172
template <typename Visitor>
174173
void visitFields(Visitor& v) {
175174
v(tokens, "tokens", "Tokens in the context");
176175
v(logits, "logits", "Logits for the tokens");
177-
v(probs, "probs", "Probabilities for the tokens");
178176
}
179177
};
180178

@@ -188,19 +186,15 @@ struct StateGeneralInstance {
188186
struct Params {
189187
Field<std::vector<int32_t>> tokens1;
190188
Field<std::vector<float>> logits1;
191-
Field<std::vector<float>> probs1;
192189
Field<std::vector<int32_t>> tokens2;
193190
Field<std::vector<float>> logits2;
194-
Field<std::vector<float>> probs2;
195191

196192
template <typename Visitor>
197193
void visitFields(Visitor& v) {
198194
v(tokens1, "tokens1", "Tokens in the first set");
199195
v(logits1, "logits1", "Logits for the first set");
200-
v(probs1, "probs1", "Probabilities for the first set");
201196
v(tokens2, "tokens2", "Tokens in the second set");
202197
v(logits2, "logits2", "Logits for the second set");
203-
v(probs2, "probs2", "Probabilities for the second set");
204198
}
205199
};
206200

code/ac/llama/LogitComparer.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,31 @@
44
#include "LogitComparer.hpp"
55
#include <cmath>
66
#include <cassert>
7+
namespace {
8+
std::unordered_map<int32_t, float> softmax(const ac::llama::TokenDataVector& data) {
9+
std::unordered_map<int32_t, float> result(data.size());
10+
11+
// Step 1: Find max logit to subtract for numerical stability
12+
float maxLogit = data[0].logit;
13+
14+
// Step 2: Compute exp(logit - maxLogit) for each element
15+
float sumExp = 0.0f;
16+
for (size_t i = 0; i < data.size(); ++i) {
17+
float p = std::exp(data[i].logit - maxLogit);
18+
result[data[i].token] = p;
19+
sumExp += p;
20+
}
721

8-
namespace ac::llama {
22+
// Step 3: Normalize to get probabilities
23+
for (auto& val : result) {
24+
val.second /= sumExp;
25+
}
26+
27+
return result;
28+
}
29+
}
930

31+
namespace ac::llama {
1032
// We apply 3 step comparison
1133
// 1. Compare the euclidean distance of the logits
1234
// - If the distance is less than 2% of the max distance, we consider them equal
@@ -25,10 +47,8 @@ bool LogitComparer::compare(const TokenDataVector& data1, const TokenDataVector&
2547
return false;
2648
}
2749

28-
std::unordered_map<int32_t, float> prob_map, prob_map2;
29-
30-
for (const auto& p : data1) prob_map[p.token] = p.prob;
31-
for (const auto& p : data2) prob_map2[p.token] = p.prob;
50+
auto prob_map = softmax(data1);
51+
auto prob_map2 = softmax(data2);
3252

3353
// Check if at least 80% of the tokens are the same
3454
float matchingTokens = 0;
@@ -56,7 +76,7 @@ float LogitComparer::logitSimilarity(const TokenDataVector& data1, const TokenDa
5676
float weightedSimSum = 0.0f;
5777
float totalWeight = 0.0f;
5878
for (auto& t : data1) {
59-
float weight = t.prob;
79+
float weight = t.logit;
6080
float sim = 0.0f;
6181
if (l2_map.count(t.token)) {
6282
sim = 1 - (std::abs(t.logit - l2_map[t.token]) / std::max(t.logit, l2_map[t.token]));

code/ac/llama/Session.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,7 @@ void fillLogits(TokenDataVector& out, llama_context* lctx) {
2929
out.resize(vocabSize);
3030

3131
for (llama_token id = 0; id < vocabSize; id++) {
32-
out[id] = {id, logits[id], 0.0f};
33-
}
34-
}
35-
36-
static void applySoftMax(TokenDataVector& data) {
37-
// Apply softmax to the logits
38-
// The vector should be sorted in descending order
39-
40-
float max_l = data[0].logit;
41-
float cum_sum = 0.0f;
42-
43-
for (size_t i = 0; i < data.size(); ++i) {
44-
float p = expf(data[i].logit - max_l);
45-
data[i].prob = p;
46-
cum_sum += p;
47-
}
48-
49-
for (size_t i = 0; i < data.size(); ++i) {
50-
data[i].prob /= cum_sum;
32+
out[id] = {id, logits[id]};
5133
}
5234
}
5335
}
@@ -209,9 +191,7 @@ TokenDataVector Session::getSampledTokenData(int32_t topK, float /*topP*/) {
209191
});
210192

211193
TokenDataVector result;
212-
result.insert(result.end(), tempData.begin(), tempData.begin() + topK);
213-
214-
applySoftMax(result);
194+
result.insert(result.end(), tempData.begin(), tempData.begin() + topK);;
215195

216196
return result;
217197
}

code/ac/llama/Token.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ inline constexpr Token Token_Invalid = -1;
1212
struct TokenData {
1313
Token token;
1414
float logit;
15-
float prob;
1615
};
1716

1817
using TokenDataVector = std::vector<TokenData>;

0 commit comments

Comments
 (0)