Skip to content

Commit fb5b5bb

Browse files
committed
feat: improve similarity logic
-- do not rely on probs, only logits -- adjust penalization: higher logit == more significant
1 parent 0989c5e commit fb5b5bb

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

code/ac/llama/LogitComparer.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,20 @@ float LogitComparer::logitSimilarity(const TokenDataVector& data1, const TokenDa
7575
for (const auto& t : data1) l_map[t.token] = t.logit;
7676
for (const auto& t : data2) l2_map[t.token] = t.logit;
7777

78+
float weightedSimSum = 0.0f;
79+
float totalWeight = 0.0f;
7880
for (auto& t : data1) {
81+
float weight = t.prob;
82+
float sim = 0.0f;
7983
if (l2_map.count(t.token)) {
80-
res += 1 - (std::abs(t.logit - l2_map[t.token]) / std::max(t.logit, l2_map[t.token]));
81-
} else {
82-
// Token not found in the second map
83-
// we should penalize the result
84-
// but we don't know how much
85-
// so we just add 0.0f for now, maybe it should be another value
86-
res += 0.0f; // Token not found in the second map
84+
sim = 1 - (std::abs(t.logit - l2_map[t.token]) / std::max(t.logit, l2_map[t.token]));
8785
}
86+
87+
weightedSimSum += weight * sim;
88+
totalWeight += weight;
8889
}
8990

90-
return res / data1.size();
91+
return totalWeight > 0.0f ? (weightedSimSum / totalWeight) : 0.0f;
9192
}
9293

9394
float LogitComparer::jsd(const std::unordered_map<Token, float>& probs1, const std::unordered_map<Token, float>& probs2) {

0 commit comments

Comments
 (0)