@@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
2727 return entropy , varentropy
2828
2929def calculate_attention_metrics (attention_weights : torch .Tensor ) -> Dict [str , torch .Tensor ]:
30- # attention_weights are already probabilities (post-softmax)
3130 attention_probs = attention_weights
3231
3332 # Calculate entropy
3433 attn_entropy = - torch .sum (attention_probs * torch .log2 (torch .clamp (attention_probs , 1e-10 , 1.0 )), dim = - 1 )
3534
36- # Calculate variance of entropy
37- attn_varentropy = torch .var (attn_entropy , dim = - 1 )
35+ # Calculate variance of entropy; unbiased=False (population variance) avoids the degrees-of-freedom NaN when only one value is present
36+ # A singleton last dimension is special-cased to zero variance
37+ if attn_entropy .size (- 1 ) > 1 :
38+ attn_varentropy = torch .var (attn_entropy , dim = - 1 , unbiased = False )
39+ else :
40+ attn_varentropy = torch .zeros_like (attn_entropy )
41+
3842 attn_varentropy = torch .where (torch .isnan (attn_varentropy ),
3943 torch .zeros_like (attn_varentropy ),
4044 attn_varentropy )
4145
42- # Calculate mean attention and agreement
46+ # Calculate mean attention and agreement
4347 mean_attention = torch .mean (attention_probs , dim = 1 )
4448 agreement = torch .mean (torch .abs (attention_probs - mean_attention .unsqueeze (1 )), dim = (1 , 2 ))
4549
46- # For interaction strength, we can use log probabilities to approximate the original scores
47- # This maintains the relative relationships while providing a reasonable proxy for attention strength
4850 attention_scores_proxy = torch .log (torch .clamp (attention_probs , 1e-10 , 1.0 ))
4951 interaction_strength = torch .mean (torch .abs (attention_scores_proxy ), dim = (1 , 2 , 3 ))
5052
0 commit comments