11import torch as th
22import torch .nn as nn
3+ import numpy as np
34from scipy .stats import entropy
45
56
@@ -22,7 +23,7 @@ def __init__(self, averages: str = "mean"):
2223 self .averages = averages
2324 self .stored_entropy_values = []
2425
25- def __call__ (self , y_true , y_logits ):
26+ def __call__ (self , y_true : th . Tensor , y_logits : th . Tensor ):
2627 """
2728 Computes the Shannon Entropy of the predicted logits and stores the results.
2829 Args:
@@ -33,27 +34,44 @@ def __call__(self, y_true, y_logits):
3334 torch.Tensor: The aggregated entropy value(s) based on the specified
3435 method ('mean', 'sum', or 'none').
3536 """
37+
38+ assert len (y_logits .size ()) == 2 , f'y_logits shape: { y_logits .size ()} '
3639 y_pred = nn .Softmax (dim = 1 )(y_logits )
40+ print (f'y_pred: { y_pred } ' )
3741 entropy_values = entropy (y_pred , axis = 1 )
3842 entropy_values = th .from_numpy (entropy_values )
3943
4044 # Fix numerical errors for perfect guesses
4145 entropy_values [entropy_values == th .inf ] = 0
4246 entropy_values = th .nan_to_num (entropy_values )
43-
47+ print ( f'Entropy Values: { entropy_values } ' )
4448 for sample in entropy_values :
4549 self .stored_entropy_values .append (sample .item ())
46-
50+
4751
4852 def __returnmetric__ (self ):
53+ stored_entropy_values = th .from_numpy (np .asarray (self .stored_entropy_values ))
54+
4955 if self .averages == "mean" :
50- self . stored_entropy_values = th .mean (self . stored_entropy_values )
56+ stored_entropy_values = th .mean (stored_entropy_values )
5157 elif self .averages == "sum" :
52- self . stored_entropy_values = th .sum (self . stored_entropy_values )
58+ stored_entropy_values = th .sum (stored_entropy_values )
5359 elif self .averages == "none" :
5460 pass
55- return self . stored_entropy_values
61+ return stored_entropy_values
5662
5763 def __reset__ (self ):
5864 self .stored_entropy_values = []
5965
66+ if __name__ == '__main__' :
67+
68+ pred_logits = th .rand (6 , 5 )
69+ true_lab = th .rand (6 , 5 )
70+
71+ metric = EntropyPrediction (averages = "mean" )
72+ metric2 = EntropyPrediction (averages = "sum" )
73+
74+ # Test for averaging metric consistency
75+ metric (true_lab , pred_logits )
76+ metric2 (true_lab , pred_logits )
77+ assert (th .abs (th .sum (6 * metric .__returnmetric__ () - metric2 .__returnmetric__ ())) < 1e-5 )
0 commit comments