Skip to content

Commit 78680a8

Browse files
committed
Fixed test_entropypred
1 parent 3df3063 commit 78680a8

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

tests/test_metrics.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,13 @@ def test_accuracy():
102102
def test_entropypred():
103103
import torch as th
104104

105-
metric = EntropyPrediction(averages="mean")
106-
107-
true_lab = th.Tensor([0, 1, 1, 2, 4, 3]).reshape(6, 1).type(th.LongTensor)
108-
pred_logits = th.nn.functional.one_hot(true_lab, 5)
109-
110-
# Test for log(0) errors and expected output
111-
assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5
112-
113105
pred_logits = th.rand(6, 5)
106+
true_lab = th.rand(6, 5)
107+
108+
metric = EntropyPrediction(averages="mean")
114109
metric2 = EntropyPrediction(averages="sum")
115-
110+
116111
# Test for averaging metric consistency
117-
assert (
118-
th.abs(
119-
th.sum(6 * metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))
120-
)
121-
< 1e-5
122-
)
112+
metric(true_lab, pred_logits)
113+
metric2(true_lab, pred_logits)
114+
assert (th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__())) < 1e-5)

utils/metrics/EntropyPred.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch as th
22
import torch.nn as nn
3+
import numpy as np
34
from scipy.stats import entropy
45

56

@@ -22,7 +23,7 @@ def __init__(self, averages: str = "mean"):
2223
self.averages = averages
2324
self.stored_entropy_values = []
2425

25-
def __call__(self, y_true, y_logits):
26+
def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
2627
"""
2728
Computes the Shannon Entropy of the predicted logits and stores the results.
2829
Args:
@@ -33,27 +34,44 @@ def __call__(self, y_true, y_logits):
3334
torch.Tensor: The aggregated entropy value(s) based on the specified
3435
method ('mean', 'sum', or 'none').
3536
"""
37+
38+
assert len(y_logits.size()) == 2, f'y_logits shape: {y_logits.size()}'
3639
y_pred = nn.Softmax(dim=1)(y_logits)
40+
print(f'y_pred: {y_pred}')
3741
entropy_values = entropy(y_pred, axis=1)
3842
entropy_values = th.from_numpy(entropy_values)
3943

4044
# Fix numerical errors for perfect guesses
4145
entropy_values[entropy_values == th.inf] = 0
4246
entropy_values = th.nan_to_num(entropy_values)
43-
47+
print(f'Entropy Values: {entropy_values}')
4448
for sample in entropy_values:
4549
self.stored_entropy_values.append(sample.item())
46-
50+
4751

4852
def __returnmetric__(self):
53+
stored_entropy_values = th.from_numpy(np.asarray(self.stored_entropy_values))
54+
4955
if self.averages == "mean":
50-
self.stored_entropy_values = th.mean(self.stored_entropy_values)
56+
stored_entropy_values = th.mean(stored_entropy_values)
5157
elif self.averages == "sum":
52-
self.stored_entropy_values = th.sum(self.stored_entropy_values)
58+
stored_entropy_values = th.sum(stored_entropy_values)
5359
elif self.averages == "none":
5460
pass
55-
return self.stored_entropy_values
61+
return stored_entropy_values
5662

5763
def __reset__(self):
5864
self.stored_entropy_values = []
5965

66+
if __name__ == '__main__':
67+
68+
pred_logits = th.rand(6, 5)
69+
true_lab = th.rand(6, 5)
70+
71+
metric = EntropyPrediction(averages="mean")
72+
metric2 = EntropyPrediction(averages="sum")
73+
74+
# Test for averaging metric consistency
75+
metric(true_lab, pred_logits)
76+
metric2(true_lab, pred_logits)
77+
assert (th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__())) < 1e-5)

0 commit comments

Comments
 (0)