File tree Expand file tree Collapse file tree 1 file changed +27
-7
lines changed
Expand file tree Collapse file tree 1 file changed +27
-7
lines changed Original file line number Diff line number Diff line change 1+ import numpy as np
12import torch .nn as nn
3+ from scipy .stats import entropy
24
35
46class EntropyPrediction (nn .Module ):
5- def __init__ (self ):
7+ def __init__ (self , averages : str = 'average' ):
8+ """
9+ Initializes the EntropyPrediction module.
10+ Args:
11+ averages (str): Specifies the method of aggregation for entropy values.
12+ Must be either 'average' or 'sum'.
13+ Raises:
14+ AssertionError: If the averages parameter is not 'average' or 'sum'.
15+ """
616 super ().__init__ ()
7-
8- def __call__ (self , y_true , y_false_logits ):
9- return
10-
11- def __reset__ (self ):
12- pass
17+
18+ assert averages == 'average' or averages == 'sum'
19+ self .averages = averages
20+ self .stored_entropy_values = []
21+
22+ def __call__ (self , y_true , y_pred_logits ):
23+ """
24+ Computes the entropy between true labels and predicted logits, storing the results.
25+ Args:
26+ y_true: The true labels.
27+ y_pred_logits: The predicted logits.
28+ Side Effects:
29+ Appends the computed entropy values to the stored_entropy_values list.
30+ """
31+ entropy_values = entropy (y_true , qk = y_pred_logits )
32+ return np .mean (entropy_values )
You can’t perform that action at this time.
0 commit comments