
Commit 9726519

Update entropy_decoding.py
1 parent 129ac80 commit 9726519

File tree

1 file changed (+8 -6 lines)

optillm/entropy_decoding.py

Lines changed: 8 additions & 6 deletions
@@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
     return entropy, varentropy
 
 def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
-    # attention_weights are already probabilities (post-softmax)
     attention_probs = attention_weights
 
     # Calculate entropy
     attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
 
-    # Calculate variance of entropy
-    attn_varentropy = torch.var(attn_entropy, dim=-1)
+    # Calculate variance of entropy with unbiased=False to avoid df issues
+    # Also add a check for singleton dimensions
+    if attn_entropy.size(-1) > 1:
+        attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
+    else:
+        attn_varentropy = torch.zeros_like(attn_entropy)
+
     attn_varentropy = torch.where(torch.isnan(attn_varentropy),
                                   torch.zeros_like(attn_varentropy),
                                   attn_varentropy)
 
-    # Calculate mean attention and agreement
+    # Rest remains the same
     mean_attention = torch.mean(attention_probs, dim=1)
     agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
 
-    # For interaction strength, we can use log probabilities to approximate the original scores
-    # This maintains the relative relationships while providing a reasonable proxy for attention strength
     attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
     interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
 

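For context on the varentropy change above: torch.var defaults to Bessel's correction (unbiased=True), which divides by N - 1, so reducing over a length-1 dimension yields NaN and a degrees-of-freedom warning. The sketch below is written for illustration only and is not part of the commit; the tensor shape is an assumption chosen to trigger the singleton case, and the guarded computation mirrors what the diff introduces.

# Minimal sketch (not part of the commit): why unbiased=False plus the
# singleton-dimension check keeps varentropy finite. Shape is illustrative.
import torch

# tensor with a singleton last dimension (the case the guard protects against)
attn_entropy = torch.tensor([[1.3]])

nan_var  = torch.var(attn_entropy, dim=-1)                   # default unbiased=True: divides by N-1=0 -> NaN
safe_var = torch.var(attn_entropy, dim=-1, unbiased=False)   # population variance -> 0.0

# Guarded computation as in the committed diff
if attn_entropy.size(-1) > 1:
    attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
else:
    attn_varentropy = torch.zeros_like(attn_entropy)

attn_varentropy = torch.where(torch.isnan(attn_varentropy),
                              torch.zeros_like(attn_varentropy),
                              attn_varentropy)

print(nan_var, safe_var, attn_varentropy)   # tensor([nan]) tensor([0.]) tensor([[0.]])

Note that the commit keeps the torch.where NaN scrub as a second line of defense even though the size check and unbiased=False already avoid the degrees-of-freedom failure on their own.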