Commit 95d1489

Added onto sphinx thingy
1 parent b4d6400 commit 95d1489

File tree

2 files changed: +22 -33 lines changed


.github/Individual_pages/Magnus_page.md

Lines changed: 14 additions & 2 deletions
```diff
@@ -35,5 +35,17 @@ The EntropyPrediction class' main job is to take some inputs and return the Shannon Entropy
 * __returnmetric__ : Returns the collected metric.
 * __reset__ : Removes all the stored values up until that point. Readies the instance for storing values from a new epoch.
 
-The class is initialized with a single parameter called "averages". This is inspired from other PyTorch and NumPy implementations and controlls how values from different batches or within batches will be combined. The __init__ method checks the value of this argument with an assertion, which must be one of three string. We only allow "mean", "sum" and "none" as methods of combining the different entropy values.
-Furtherore, this method will also store the different Shannon Entropy values as we pass values into the __call__ method.
+The class is initialized with a single parameter called "averages". This is inspired by other PyTorch and NumPy implementations and controls how values from different batches or within batches will be combined. The __init__ method checks the value of this argument with an assertion; it must be one of three strings. We only allow "mean", "sum" and "none" as methods of combining the different entropy values. We'll come back to the specifics here.
+Furthermore, this method will also store the different Shannon Entropy values as we pass values into the __call__ method.
+
+In __call__ we get both true labels and model logit scores for each sample in the batch as input. We're calculating Shannon Entropy, not KL-divergence, so the true labels aren't needed.
+With permission I've used the scipy implementation to calculate entropy here. We apply a softmax over the logit values, then calculate the Shannon Entropy, and make sure to remove any NaN or Inf values which might arise from a perfect guess/distribution.
+
+Next we have the __returnmetric__ method, which is used to retrieve the stored metric. Here the averages argument comes into play.
+Depending on what was chosen as the averaging method when initializing the class, one of the following operations is applied to the stored values:
+* Mean: Calculate the mean of the stored entropy values.
+* Sum: Sum the stored entropy values.
+* None: Do nothing with the stored entropy values.
+Then the value(s) are returned.
+
+Lastly we have the __reset__ method, which simply empties the variable that stores the entropy values to prepare it for the next epoch.
```
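
Taken together, the documentation above describes a four-step lifecycle: validate "averages" in __init__, accumulate per-sample entropies in __call__, aggregate them in __returnmetric__, and clear them in __reset__. Below is a minimal self-contained sketch of that interface (an illustration, not the repository file verbatim; in particular, it converts the stored Python list to a tensor before th.mean / th.sum, a step the prose leaves implicit):

```python
import torch as th
import torch.nn as nn
from scipy.stats import entropy


class EntropyPrediction:
    """Collects per-sample Shannon Entropy values over an epoch (sketch)."""

    def __init__(self, averages: str):
        # Only these three aggregation modes are accepted.
        assert averages in ("mean", "sum", "none")
        self.averages = averages
        self.stored_entropy_values = []

    def __call__(self, y_true, y_logits):
        # y_true is unused: Shannon Entropy needs only the predicted distribution.
        y_pred = nn.Softmax(dim=1)(y_logits)
        entropy_values = th.from_numpy(entropy(y_pred.detach().numpy(), axis=1))
        # Fix numerical errors arising from perfect guesses/distributions.
        entropy_values[entropy_values == th.inf] = 0
        entropy_values = th.nan_to_num(entropy_values)
        for sample in entropy_values:
            self.stored_entropy_values.append(sample.item())

    def __returnmetric__(self):
        # The store is a plain list, so convert before aggregating (assumption).
        values = th.tensor(self.stored_entropy_values)
        if self.averages == "mean":
            return th.mean(values)
        elif self.averages == "sum":
            return th.sum(values)
        return values  # "none": per-sample values, untouched

    def __reset__(self):
        self.stored_entropy_values = []
```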

utils/metrics/EntropyPred.py

Lines changed: 8 additions & 31 deletions
```diff
@@ -1,3 +1,4 @@
+import torch as th
 import torch.nn as nn
 from scipy.stats import entropy
 
@@ -32,51 +33,27 @@ def __call__(self, y_true, y_logits):
             torch.Tensor: The aggregated entropy value(s) based on the specified
                 method ('mean', 'sum', or 'none').
         """
-        entropy_values = entropy(y_logits, axis=1)
+        y_pred = nn.Softmax(dim=1)(y_logits)
+        entropy_values = entropy(y_pred, axis=1)
         entropy_values = th.from_numpy(entropy_values)
 
         # Fix numerical errors for perfect guesses
         entropy_values[entropy_values == th.inf] = 0
         entropy_values = th.nan_to_num(entropy_values)
 
-        if self.averages == "mean":
-            entropy_values = th.mean(entropy_values)
-        elif self.averages == "sum":
-            entropy_values = th.sum(entropy_values)
-        elif self.averages == "none":
-            return entropy_values
-
-        self.stored_entropy_values.append(entropy_values)
-
-        return entropy_values
+        for sample in entropy_values:
+            self.stored_entropy_values.append(sample.item())
+
 
     def __returnmetric__(self):
         if self.averages == "mean":
             self.stored_entropy_values = th.mean(self.stored_entropy_values)
         elif self.averages == "sum":
             self.stored_entropy_values = th.sum(self.stored_entropy_values)
         elif self.averages == "none":
-            return self.stored_entropy_values
+            pass
+        return self.stored_entropy_values
 
     def __reset__(self):
         self.stored_entropy_values = []
 
-
-if __name__ == "__main__":
-    import torch as th
-
-    metric = EntropyPrediction(averages="mean")
-
-    true_lab = th.Tensor([0, 1, 1, 2, 4, 3]).reshape(6, 1)
-    pred_logits = th.nn.functional.one_hot(true_lab, 5)
-
-    assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5
-
-    pred_logits = th.rand(6, 5)
-    metric2 = EntropyPrediction(averages="sum")
-    assert (
-        th.abs(
-            th.sum(6 * metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))
-        )
-        < 1e-5
-    )
```
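
To see what the reworked __call__ now computes, here is a standalone check of the softmax-then-entropy step on two mock logit rows (illustrative numbers only, same imports as the file above):

```python
import torch as th
import torch.nn as nn
from scipy.stats import entropy

# Two mock logit rows: one confident prediction, one perfectly uniform.
y_logits = th.tensor([[4.0, 0.1, 0.1],
                      [1.0, 1.0, 1.0]])

y_pred = nn.Softmax(dim=1)(y_logits)                # rows now sum to 1
per_sample = th.from_numpy(entropy(y_pred.numpy(), axis=1))

print(per_sample)  # approx tensor([0.1914, 1.0986]); ln(3) = 1.0986 is the 3-class maximum
```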
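
The commit also removes the inline if __name__ == "__main__" self-test. For reference, a rough replacement exercising the new store-per-sample behaviour might look like the sketch below; the import path is an assumption based on the file listing above, and averages="none" is used here to show the raw per-sample store:

```python
import torch as th

# Assumed import path, inferred from the file shown in this commit.
from utils.metrics.EntropyPred import EntropyPrediction

metric = EntropyPrediction(averages="none")   # keep per-sample values

for _ in range(3):                            # stand-in for three batches
    y_true = th.randint(0, 5, (6,))           # accepted but unused by __call__
    y_logits = th.rand(6, 5)
    metric(y_true, y_logits)                  # appends 6 entropy values per batch

print(len(metric.__returnmetric__()))         # 18 stored per-sample entropies
metric.__reset__()                            # empty the store for the next epoch
```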
