Skip to content

Commit 129ac80

Browse files
committed
Update entropy_decoding.py
update attention scores
1 parent 996e48d commit 129ac80

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

optillm/entropy_decoding.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,28 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
2626
varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis)
2727
return entropy, varentropy
2828

29-
def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
30-
# attention_probs = F.softmax(attention_scores, dim=-1)
31-
attention_probs = attention_scores
29+
def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
30+
# attention_weights are already probabilities (post-softmax)
31+
attention_probs = attention_weights
32+
33+
# Calculate entropy
3234
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
35+
36+
# Calculate variance of entropy
3337
attn_varentropy = torch.var(attn_entropy, dim=-1)
38+
attn_varentropy = torch.where(torch.isnan(attn_varentropy),
39+
torch.zeros_like(attn_varentropy),
40+
attn_varentropy)
3441

35-
attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
42+
# Calculate mean attention and agreement
3643
mean_attention = torch.mean(attention_probs, dim=1)
3744
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
38-
39-
interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
40-
45+
46+
# For interaction strength, we can use log probabilities to approximate the original scores
47+
# This maintains the relative relationships while providing a reasonable proxy for attention strength
48+
attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
49+
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
50+
4151
return {
4252
"attn_entropy": torch.mean(attn_entropy),
4353
"attn_varentropy": torch.mean(attn_varentropy),

0 commit comments

Comments (0)