|
| 1 | +import torch as th |
1 | 2 | import torch.nn as nn |
2 | 3 | from scipy.stats import entropy |
3 | 4 |
|
@@ -32,51 +33,27 @@ def __call__(self, y_true, y_logits): |
32 | 33 | torch.Tensor: The aggregated entropy value(s) based on the specified |
33 | 34 | method ('mean', 'sum', or 'none'). |
34 | 35 | """ |
35 | | - entropy_values = entropy(y_logits, axis=1) |
| 36 | + y_pred = nn.Softmax(dim=1)(y_logits) |
| 37 | + entropy_values = entropy(y_pred, axis=1) |
36 | 38 | entropy_values = th.from_numpy(entropy_values) |
37 | 39 |
|
38 | 40 | # Fix numerical errors for perfect guesses |
39 | 41 | entropy_values[entropy_values == th.inf] = 0 |
40 | 42 | entropy_values = th.nan_to_num(entropy_values) |
41 | 43 |
|
42 | | - if self.averages == "mean": |
43 | | - entropy_values = th.mean(entropy_values) |
44 | | - elif self.averages == "sum": |
45 | | - entropy_values = th.sum(entropy_values) |
46 | | - elif self.averages == "none": |
47 | | - return entropy_values |
48 | | - |
49 | | - self.stored_entropy_values.append(entropy_values) |
50 | | - |
51 | | - return entropy_values |
| 44 | + for sample in entropy_values: |
| 45 | + self.stored_entropy_values.append(sample.item()) |
| 46 | + |
52 | 47 |
|
53 | 48 | def __returnmetric__(self): |
54 | 49 | if self.averages == "mean": |
55 | 50 | self.stored_entropy_values = th.mean(self.stored_entropy_values) |
56 | 51 | elif self.averages == "sum": |
57 | 52 | self.stored_entropy_values = th.sum(self.stored_entropy_values) |
58 | 53 | elif self.averages == "none": |
59 | | - return self.stored_entropy_values |
| 54 | + pass |
| 55 | + return self.stored_entropy_values |
60 | 56 |
|
61 | 57 | def __reset__(self): |
62 | 58 | self.stored_entropy_values = [] |
63 | 59 |
|
64 | | - |
65 | | -if __name__ == "__main__": |
66 | | - import torch as th |
67 | | - |
68 | | - metric = EntropyPrediction(averages="mean") |
69 | | - |
70 | | - true_lab = th.Tensor([0, 1, 1, 2, 4, 3]).reshape(6, 1) |
71 | | - pred_logits = th.nn.functional.one_hot(true_lab, 5) |
72 | | - |
73 | | - assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5 |
74 | | - |
75 | | - pred_logits = th.rand(6, 5) |
76 | | - metric2 = EntropyPrediction(averages="sum") |
77 | | - assert ( |
78 | | - th.abs( |
79 | | - th.sum(6 * metric(true_lab, pred_logits) - metric2(true_lab, pred_logits)) |
80 | | - ) |
81 | | - < 1e-5 |
82 | | - ) |
0 commit comments