File tree Expand file tree Collapse file tree 2 files changed +29
-12
lines changed
Expand file tree Collapse file tree 2 files changed +29
-12
lines changed Original file line number Diff line number Diff line change @@ -107,8 +107,7 @@ def main():
107107 optimizer .step ()
108108 optimizer .zero_grad (set_to_none = True )
109109
110- preds = th .argmax (logits , dim = 1 )
111- metrics (y , preds )
110+ metrics (y , logits )
112111
113112 break
114113 print (metrics .accumulate ())
@@ -134,8 +133,7 @@ def main():
134133 optimizer .zero_grad (set_to_none = True )
135134 trainingloss .append (loss .item ())
136135
137- preds = th .argmax (logits , dim = 1 )
138- metrics (y , preds )
136+ metrics (y , logits )
139137
140138 wandb .log (metrics .accumulate (str_prefix = "Train " ))
141139 metrics .reset ()
@@ -150,8 +148,7 @@ def main():
150148 loss = criterion (logits , y )
151149 evalloss .append (loss .item ())
152150
153- preds = th .argmax (logits , dim = 1 )
154- metrics (y , preds )
151+ metrics (y , logits )
155152
156153 wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
157154 metrics .reset ()
Original file line number Diff line number Diff line change 1+ import numpy as np
12import torch .nn as nn
3+ from scipy .stats import entropy
24
35
46class EntropyPrediction (nn .Module ):
5- def __init__ (self ):
7+ def __init__ (self , averages : str = 'average' ):
8+ """
9+ Initializes the EntropyPrediction module.
10+ Args:
11+ averages (str): Specifies the method of aggregation for entropy values.
12+ Must be either 'average' or 'sum'.
13+ Raises:
14+ AssertionError: If the averages parameter is not 'average' or 'sum'.
15+ """
616 super ().__init__ ()
7-
17+
18+ assert averages == 'average' or averages == 'sum'
19+ self .averages = averages
20+ self .stored_entropy_values = []
21+
822 def __call__ (self , y_true , y_false_logits ):
9- return
10-
11- def __reset__ (self ):
12- pass
23+ """
24+ Computes the entropy between true labels and predicted logits, storing the results.
25+ Args:
26+ y_true: The true labels.
27+ y_false_logits: The predicted logits.
28+ Side Effects:
29+ Appends the computed entropy values to the stored_entropy_values list.
30+ """
31+ entropy_values = entropy (y_true , qk = y_false_logits )
32+ return entropy_values
You can’t perform that action at this time.
0 commit comments