Skip to content

Commit d61e0cd

Browse files
committed
feat: add logitSimilarity method to compare logit values between token data
1 parent 6780666 commit d61e0cd

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

code/ac/llama/LogitComparer.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@ float LogitComparer::cosineDistance(const TokenDataVector& data1, const TokenDat
6666
return 1.0f - (dot / (std::sqrt(normA) * std::sqrt(normB)));
6767
}
6868

69+
float LogitComparer::logitSimilarity(const TokenDataVector& data1, const TokenDataVector& data2) {
70+
float res = 0.0f;
71+
72+
assert(data1.size() == data2.size());
73+
std::unordered_map<int32_t, float> l_map, l2_map;
74+
75+
for (const auto& t : data1) l_map[t.token] = t.logit;
76+
for (const auto& t : data2) l2_map[t.token] = t.logit;
77+
78+
for (auto& t : data1) {
79+
if (l2_map.count(t.token)) {
80+
res += 1 - (std::abs(t.logit - l2_map[t.token]) / std::max(t.logit, l2_map[t.token]));
81+
} else {
82+
// Token not found in the second map
83+
// we should penalize the result
84+
// but we don't know how much
85+
// so we just add 0.0f for now, maybe it should be another value
86+
res += 0.0f; // Token not found in the second map
87+
}
88+
}
89+
90+
return res / data1.size();
91+
}
92+
6993
float LogitComparer::jsd(const std::unordered_map<Token, float>& probs1, const std::unordered_map<Token, float>& probs2) {
7094
std::unordered_map<Token, float> avg_dist;
7195
for (const auto& [token, p] : probs1) {

code/ac/llama/LogitComparer.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class LogitComparer {
1616

1717
static float cosineDistance(const TokenDataVector& data1, const TokenDataVector& data2);
1818

19+
static float logitSimilarity(const TokenDataVector& data1, const TokenDataVector& data2);
20+
1921
private:
2022
static float jsd(const std::unordered_map<Token, float>& logits1, const std::unordered_map<Token, float>& logits2);
2123
static float euclidean_distance_sq(std::span<const TokenData> tokens);

0 commit comments

Comments
 (0)