Skip to content

Commit 6780666

Browse files
committed
refactor: return the JSD comparison, which seems to be more suitable
-- however, for now this cannot tell which generation might be bad.
1 parent f650d88 commit 6780666

File tree

3 files changed

+68
-32
lines changed

3 files changed

+68
-32
lines changed

code/ac/llama/LogitComparer.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,43 @@ namespace ac::llama {
1414
// - If at least 80% of the tokens are the same, we consider them equal
1515
// 3. Compare the Jensen-Shannon divergence of the probabilities
1616
// - If the divergence is less than the threshold, we consider them equal
17-
float LogitComparer::compare(const TokenDataVector& data1, const TokenDataVector& data2) {
18-
// const auto minSize = std::min(data1.size(), data2.size());
19-
// float distance1 = euclidean_distance_sq({data1.data(), minSize});
20-
// float distance2 = euclidean_distance_sq({data2.data(), minSize});
21-
22-
// float relative_threshold = 0.02f; // 2% difference allowed
23-
// float res = std::fabs(distance1 - distance2) / std::max(distance1, distance2);
24-
// if (res > relative_threshold) {
25-
// return false;
26-
// }
17+
bool LogitComparer::compare(const TokenDataVector& data1, const TokenDataVector& data2) {
18+
const auto minSize = std::min(data1.size(), data2.size());
19+
float distance1 = euclidean_distance_sq({data1.data(), minSize});
20+
float distance2 = euclidean_distance_sq({data2.data(), minSize});
21+
22+
float relative_threshold = 0.02f; // 2% difference allowed
23+
float res = std::fabs(distance1 - distance2) / std::max(distance1, distance2);
24+
if (res > relative_threshold) {
25+
return false;
26+
}
2727

2828
std::unordered_map<int32_t, float> prob_map, prob_map2;
2929

3030
for (const auto& p : data1) prob_map[p.token] = p.prob;
3131
for (const auto& p : data2) prob_map2[p.token] = p.prob;
3232

3333
// Check if at least 80% of the tokens are the same
34-
// float matchingTokens = 0;
35-
// for (const auto& p : data1) {
36-
// if (prob_map2.count(p.token)) {
37-
// matchingTokens++;
38-
// }
39-
// }
40-
41-
// float matchingPercentage = matchingTokens / minSize;
42-
// if (matchingPercentage < 0.8f) {
43-
// return false;
44-
// }
34+
float matchingTokens = 0;
35+
for (const auto& p : data1) {
36+
if (prob_map2.count(p.token)) {
37+
matchingTokens++;
38+
}
39+
}
40+
41+
float matchingPercentage = matchingTokens / minSize;
42+
if (matchingPercentage < 0.8f) {
43+
return false;
44+
}
45+
46+
return jsd(prob_map, prob_map2) < 0.01f; // 1% divergence allowed
47+
}
48+
49+
float LogitComparer::JSD(const TokenDataVector& data1, const TokenDataVector& data2) {
50+
std::unordered_map<int32_t, float> prob_map, prob_map2;
51+
52+
for (const auto& p : data1) prob_map[p.token] = p.prob;
53+
for (const auto& p : data2) prob_map2[p.token] = p.prob;
4554

4655
return jsd(prob_map, prob_map2);
4756
}

code/ac/llama/LogitComparer.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace ac::llama {
1010

1111
class LogitComparer {
1212
public:
13-
static float compare(const TokenDataVector& data1, const TokenDataVector& data2);
13+
static bool compare(const TokenDataVector& data1, const TokenDataVector& data2);
14+
15+
static float JSD(const TokenDataVector& data1, const TokenDataVector& data2);
1416

1517
static float cosineDistance(const TokenDataVector& data1, const TokenDataVector& data2);
1618

example/e-verify.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ int main() try {
170170
std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
171171
// std::string modelGguf = "BgGPT-Gemma-2-2B-IT-v1.0.Q8_0.gguf";
172172
// std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
173-
std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
173+
std::string modelGguf2 = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
174+
// std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
174175

175176
Model m1(tmpFolder + modelGguf, {});
176177
Model m2(tmpFolder + modelGguf2, {});
@@ -183,9 +184,10 @@ int main() try {
183184
std::cout << "Models to compare:\n" << modelGguf << "\n" << modelGguf2 << "\n";
184185
std::cout << "Comparing...\n";
185186

187+
std::vector<float> jsdResults;
186188
for (int i = 0; i < 1; ++i) {
187189

188-
auto res = m1.generate(prompt, 1000);
190+
auto res = m1.generate(prompt, 100);
189191
std::cout << "Model 1 generated: " << res.result << "\n";
190192
std::string genPrompt = res.initalPrompt;
191193

@@ -207,6 +209,8 @@ int main() try {
207209
totalWeightedDist += weight * fakeDist;
208210
totalWeight += weight;
209211

212+
jsdResults.push_back(1);
213+
210214
std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
211215
continue;
212216
}
@@ -222,16 +226,24 @@ int main() try {
222226

223227
assert(res2.steps.size() == 1);
224228

225-
// Step 1: Compare logits
226-
float dist = ac::llama::LogitComparer::cosineDistance(step.data, res2.steps[0].data);
229+
{
230+
// Step 1: Compare logits
231+
float dist = ac::llama::LogitComparer::cosineDistance(step.data, res2.steps[0].data);
232+
233+
// Step 2: Calculate confidence weight
234+
float entropy = normalizedEntropy(step.data);
235+
float weight = 1.0f - entropy; // high confidence = high weight
227236

228-
// Step 2: Calculate confidence weight
229-
float entropy = normalizedEntropy(step.data);
230-
float weight = 1.0f - entropy; // high confidence = high weight
237+
// Step 3: Accumulate weighted distance
238+
totalWeightedDist += weight * dist;
239+
totalWeight += weight;
240+
}
241+
242+
{
243+
float jsd = ac::llama::LogitComparer::JSD(step.data, res2.steps[0].data);
244+
jsdResults.push_back(jsd);
245+
}
231246

232-
// Step 3: Accumulate weighted distance
233-
totalWeightedDist += weight * dist;
234-
totalWeight += weight;
235247
}
236248

237249
// Final step: Normalize
@@ -245,6 +257,19 @@ int main() try {
245257
float finalScore = (totalWeight > 0.0f) ? (totalWeightedDist / totalWeight) : 0.0f;
246258
std::cout << "Final weighted distance score: " << finalScore << "\n";
247259

260+
// Final score interpretation
261+
// average JSD score
262+
// 0.0 | Perfect match (identical predictions)
263+
// 0.0001 - 0.001 | Practically indistinguishable
264+
// 0.001 - 0.01 | Moderate variation, likely different versions/settings
265+
// 0.01 - 0.1 | Large differences, likely different models
266+
float jsdSum = 0.0f;
267+
for (const auto& jsd : jsdResults) {
268+
jsdSum += jsd;
269+
}
270+
float jsdAvg = jsdSum / jsdResults.size();
271+
std::cout << "Average JSD score: " << jsdAvg << "\n";
272+
248273
}
249274
std::cout << '\n';
250275

0 commit comments

Comments
 (0)