Skip to content

Commit f650d88

Browse files
committed
refactor: change JSD with entropy to test the behavior
1 parent a8c0665 commit f650d88

File tree

5 files changed

+158
-37
lines changed

5 files changed

+158
-37
lines changed

code/ac/llama/LogitComparer.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//
44
#include "LogitComparer.hpp"
55
#include <cmath>
6+
#include <cassert>
67

78
namespace ac::llama {
89

@@ -13,36 +14,47 @@ namespace ac::llama {
1314
// - If at least 80% of the tokens are the same, we consider them equal
1415
// 3. Compare the Jensen-Shannon divergence of the probabilities
1516
// - If the divergence is less than the threshold, we consider them equal
16-
bool LogitComparer::compare(const TokenDataVector& data1, const TokenDataVector& data2) {
17-
const auto minSize = std::min(data1.size(), data2.size());
18-
float distance1 = euclidean_distance_sq({data1.data(), minSize});
19-
float distance2 = euclidean_distance_sq({data2.data(), minSize});
20-
21-
float relative_threshold = 0.02f; // 2% difference allowed
22-
float res = std::fabs(distance1 - distance2) / std::max(distance1, distance2);
23-
if (res > relative_threshold) {
24-
return false;
25-
}
17+
float LogitComparer::compare(const TokenDataVector& data1, const TokenDataVector& data2) {
18+
// const auto minSize = std::min(data1.size(), data2.size());
19+
// float distance1 = euclidean_distance_sq({data1.data(), minSize});
20+
// float distance2 = euclidean_distance_sq({data2.data(), minSize});
21+
22+
// float relative_threshold = 0.02f; // 2% difference allowed
23+
// float res = std::fabs(distance1 - distance2) / std::max(distance1, distance2);
24+
// if (res > relative_threshold) {
25+
// return false;
26+
// }
2627

2728
std::unordered_map<int32_t, float> prob_map, prob_map2;
2829

2930
for (const auto& p : data1) prob_map[p.token] = p.prob;
3031
for (const auto& p : data2) prob_map2[p.token] = p.prob;
3132

3233
// Check if at least 80% of the tokens are the same
33-
float matchingTokens = 0;
34-
for (const auto& p : data1) {
35-
if (prob_map2.count(p.token)) {
36-
matchingTokens++;
37-
}
38-
}
34+
// float matchingTokens = 0;
35+
// for (const auto& p : data1) {
36+
// if (prob_map2.count(p.token)) {
37+
// matchingTokens++;
38+
// }
39+
// }
3940

40-
float matchingPercentage = matchingTokens / minSize;
41-
if (matchingPercentage < 0.8f) {
42-
return false;
43-
}
41+
// float matchingPercentage = matchingTokens / minSize;
42+
// if (matchingPercentage < 0.8f) {
43+
// return false;
44+
// }
45+
46+
return jsd(prob_map, prob_map2);
47+
}
4448

45-
return jsd(prob_map, prob_map2) < 0.01;
49+
float LogitComparer::cosineDistance(const TokenDataVector& data1, const TokenDataVector& data2) {
50+
assert(data1.size() == data2.size());
51+
float dot = 0.0f, normA = 0.0f, normB = 0.0f;
52+
for (size_t i = 0; i < data1.size(); ++i) {
53+
dot += data1[i].logit * data2[i].logit;
54+
normA += data1[i].logit * data1[i].logit;
55+
normB += data2[i].logit * data2[i].logit;
56+
}
57+
return 1.0f - (dot / (std::sqrt(normA) * std::sqrt(normB)));
4658
}
4759

4860
float LogitComparer::jsd(const std::unordered_map<Token, float>& probs1, const std::unordered_map<Token, float>& probs2) {

code/ac/llama/LogitComparer.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace ac::llama {
1010

1111
class LogitComparer {
1212
public:
13-
static bool compare(const TokenDataVector& data1, const TokenDataVector& data2);
13+
static float compare(const TokenDataVector& data1, const TokenDataVector& data2);
14+
15+
static float cosineDistance(const TokenDataVector& data1, const TokenDataVector& data2);
1416

1517
private:
1618
static float jsd(const std::unordered_map<Token, float>& logits1, const std::unordered_map<Token, float>& logits2);

code/ac/llama/Vocab.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ bool Vocab::isEog(Token token) const noexcept {
3030
return llama_vocab_is_eog(m_lVocab, token);
3131
}
3232

33+
int32_t Vocab::nTokens() const noexcept {
34+
return llama_vocab_n_tokens(m_lVocab);
35+
}
36+
3337
std::vector<Token> Vocab::tokenize(std::string_view text, bool addSpecial, bool parseSpecial) const {
3438
int32_t numTokens = int32_t(text.length()) + 2 * addSpecial; // optimistic max
3539
std::vector<Token> ret(numTokens);

code/ac/llama/Vocab.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class AC_LLAMA_EXPORT Vocab {
2323
Token decoderStartToken() const noexcept; // fallback to bos if not available
2424

2525
bool isEog(Token token) const noexcept;
26+
int32_t nTokens() const noexcept;
2627

2728
std::string tokenToString(Token token, bool special = true) const;
2829

example/e-verify.cpp

Lines changed: 117 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class Model {
3939
m_instance.reset(new ac::llama::Instance(*m_model, {
4040
.ctxSize = 2048,
4141
}));
42+
m_session = &m_instance->startSession({});
43+
m_session->setInitialPrompt({}); // empty prompt
44+
}
45+
46+
~Model() {
47+
m_instance->stopSession();
4248
}
4349

4450
struct GenerationResult {
@@ -48,10 +54,27 @@ class Model {
4854
};
4955

5056
GenerationResult generate(std::string prompt, uint32_t maxTokens) {
51-
m_session = &m_instance->startSession({});
52-
5357
auto promptTokens = m_model->vocab().tokenize(prompt, true, true);
54-
m_session->setInitialPrompt(promptTokens);
58+
return generate_impl(promptTokens, maxTokens);
59+
}
60+
61+
GenerationResult generate(std::span<ac::llama::Token> prompt, uint32_t maxTokens) {
62+
return generate_impl(prompt, maxTokens);
63+
}
64+
65+
std::vector<ac::llama::Token> tokenize(std::string prompt) {
66+
return m_model->vocab().tokenize(prompt, true, true);
67+
}
68+
69+
bool tokenExists(ac::llama::Token token) {
70+
return m_model->vocab().nTokens() > token;
71+
}
72+
73+
private:
74+
GenerationResult generate_impl(std::span<ac::llama::Token> promptTokens, uint32_t maxTokens) {
75+
if (!promptTokens.empty()) {
76+
m_session->pushPrompt(promptTokens, {});
77+
}
5578

5679
constexpr int32_t topK = 10;
5780
auto data = m_session->getSampledTokenData(topK);
@@ -85,13 +108,15 @@ class Model {
85108
});
86109
}
87110

88-
m_instance->stopSession();
89-
m_session = nullptr;
111+
std::string initialPrompt = "";
112+
for (size_t i = 0; i < promptTokens.size(); i++){
113+
initialPrompt += m_model->vocab().tokenToString(promptTokens[i], false);
114+
}
90115

91116
return {
92-
.initalPrompt = prompt,
93-
.result = result,
94-
.steps = genSteps
117+
.initalPrompt = std::move(initialPrompt),
118+
.result = std::move(result),
119+
.steps = std::move(genSteps)
95120
};
96121
}
97122

@@ -101,6 +126,35 @@ class Model {
101126
ac::llama::Session* m_session;
102127
};
103128

129+
// -- Helper function to compute normalized entropy --
130+
float normalizedEntropy(const ac::llama::TokenDataVector& data) {
131+
std::vector<float> probs(data.size());
132+
float sum = 0.0f;
133+
134+
// Calculate softmax probabilities
135+
for (auto& val : data) {
136+
sum += std::exp(val.logit);
137+
}
138+
for (size_t i = 0; i < data.size(); ++i) {
139+
probs[i] = std::exp(data[i].logit) / sum;
140+
}
141+
142+
// Calculate entropy
143+
float entropy = 0.0f;
144+
for (float p : probs) {
145+
if (p > 0.0f) {
146+
entropy -= p * std::log(p);
147+
}
148+
}
149+
150+
// Normalize entropy by maximum possible entropy (log(number of classes))
151+
float maxEntropy = std::log(float(probs.size()));
152+
return entropy / maxEntropy;
153+
}
154+
155+
156+
157+
104158
int main() try {
105159
ac::jalog::Instance jl;
106160
jl.setup().add<ac::jalog::sinks::ColorSink>();
@@ -112,7 +166,10 @@ int main() try {
112166

113167
// load model
114168
std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/";
115-
std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
169+
// std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
170+
std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
171+
// std::string modelGguf = "BgGPT-Gemma-2-2B-IT-v1.0.Q8_0.gguf";
172+
// std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
116173
std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
117174

118175
Model m1(tmpFolder + modelGguf, {});
@@ -128,21 +185,66 @@ int main() try {
128185

129186
for (int i = 0; i < 1; ++i) {
130187

131-
auto res = m1.generate(prompt, 100);
188+
auto res = m1.generate(prompt, 1000);
132189
std::cout << "Model 1 generated: " << res.result << "\n";
133190
std::string genPrompt = res.initalPrompt;
191+
192+
auto genPromptTokens = m2.tokenize(genPrompt);
193+
194+
float totalWeightedDist = 0.0f;
195+
float totalWeight = 0.0f;
196+
134197
for (size_t i = 0; i < res.steps.size(); i++) {
135198
auto& step = res.steps[i];
136199
if (i > 0) {
137-
genPrompt += step.tokenStr;
200+
if (m2.tokenExists(step.token)) {
201+
genPromptTokens.push_back(step.token);
202+
}
203+
else {
204+
// Instead of skipping, penalize fully
205+
float fakeDist = 1.0f; // Maximum possible distance
206+
float weight = 1.0f; // Assume maximum confidence since we can't know entropy
207+
totalWeightedDist += weight * fakeDist;
208+
totalWeight += weight;
209+
210+
std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
211+
continue;
212+
}
138213
}
139-
auto res2 = m2.generate(genPrompt, 0);
140-
assert(res2.steps.size() == 1);
141214

142-
if (ac::llama::LogitComparer::compare(step.data, res2.steps[0].data)) {
143-
std::cout << "Models are the same. Generated str by now:\n" << genPrompt << "\n\n";
215+
Model::GenerationResult res2;
216+
if (i == 0) {
217+
res2 = m2.generate(genPromptTokens, 0);
218+
} else {
219+
std::vector<ac::llama::Token> token{step.token};
220+
res2 = m2.generate(token, 0);
144221
}
222+
223+
assert(res2.steps.size() == 1);
224+
225+
// Step 1: Compare logits
226+
float dist = ac::llama::LogitComparer::cosineDistance(step.data, res2.steps[0].data);
227+
228+
// Step 2: Calculate confidence weight
229+
float entropy = normalizedEntropy(step.data);
230+
float weight = 1.0f - entropy; // high confidence = high weight
231+
232+
// Step 3: Accumulate weighted distance
233+
totalWeightedDist += weight * dist;
234+
totalWeight += weight;
145235
}
236+
237+
// Final step: Normalize
238+
239+
// Score range | Interpretation
240+
// 0.0 | Perfect match (identical predictions)
241+
// 0.0001 - 0.001 | Practically indistinguishable
242+
// 0.001 - 0.01 | Very close, slight variation
243+
// 0.01 - 0.1 | Moderate variation, likely different versions/settings
244+
// 0.1 - 1.0 | Large differences, likely different models
245+
float finalScore = (totalWeight > 0.0f) ? (totalWeightedDist / totalWeight) : 0.0f;
246+
std::cout << "Final weighted distance score: " << finalScore << "\n";
247+
146248
}
147249
std::cout << '\n';
148250

0 commit comments

Comments
 (0)