@@ -26,18 +26,28 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
2626 varentropy = torch .sum (probs * (log_probs / LN_2 + entropy .unsqueeze (- 1 ))** 2 , dim = axis )
2727 return entropy , varentropy
2828
29- def calculate_attention_metrics (attention_scores : torch .Tensor ) -> Dict [str , torch .Tensor ]:
30- # attention_probs = F.softmax(attention_scores, dim=-1)
31- attention_probs = attention_scores
29+ def calculate_attention_metrics (attention_weights : torch .Tensor ) -> Dict [str , torch .Tensor ]:
30+ # attention_weights are already probabilities (post-softmax)
31+ attention_probs = attention_weights
32+
33+ # Calculate entropy
3234 attn_entropy = - torch .sum (attention_probs * torch .log2 (torch .clamp (attention_probs , 1e-10 , 1.0 )), dim = - 1 )
35+
36+ # Calculate variance of entropy
3337 attn_varentropy = torch .var (attn_entropy , dim = - 1 )
38+ attn_varentropy = torch .where (torch .isnan (attn_varentropy ),
39+ torch .zeros_like (attn_varentropy ),
40+ attn_varentropy )
3441
35- attn_varentropy = torch . where ( torch . isnan ( attn_varentropy ), torch . zeros_like ( attn_varentropy ), attn_varentropy )
42+ # Calculate mean attention and agreement
3643 mean_attention = torch .mean (attention_probs , dim = 1 )
3744 agreement = torch .mean (torch .abs (attention_probs - mean_attention .unsqueeze (1 )), dim = (1 , 2 ))
38-
39- interaction_strength = torch .mean (torch .abs (attention_scores ), dim = (1 , 2 , 3 ))
40-
45+
46+ # For interaction strength, we can use log probabilities to approximate the original scores
47+ # This maintains the relative relationships while providing a reasonable proxy for attention strength
48+ attention_scores_proxy = torch .log (torch .clamp (attention_probs , 1e-10 , 1.0 ))
49+ interaction_strength = torch .mean (torch .abs (attention_scores_proxy ), dim = (1 , 2 , 3 ))
50+
4151 return {
4252 "attn_entropy" : torch .mean (attn_entropy ),
4353 "attn_varentropy" : torch .mean (attn_varentropy ),
0 commit comments