Skip to content

Commit a5e966d

Browse files
committed
Fixed massive error
1 parent 4e2c8dd commit a5e966d

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

utils/metrics/EntropyPred.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
1+
import numpy as np
12
import torch.nn as nn
3+
from scipy.stats import entropy
24

35

46
class EntropyPrediction(nn.Module):
5-
def __init__(self):
7+
def __init__(self, averages: str = 'average'):
8+
"""
9+
Initializes the EntropyPrediction module.
10+
Args:
11+
averages (str): Specifies the method of aggregation for entropy values.
12+
Must be either 'average' or 'sum'.
13+
Raises:
14+
AssertionError: If the averages parameter is not 'average' or 'sum'.
15+
"""
616
super().__init__()
7-
8-
def __call__(self, y_true, y_false_logits):
9-
return
10-
11-
def __reset__(self):
12-
pass
17+
18+
assert averages == 'average' or averages == 'sum'
19+
self.averages = averages
20+
self.stored_entropy_values = []
21+
22+
def __call__(self, y_true, y_pred_logits):
23+
"""
24+
Computes the entropy between true labels and predicted logits, storing the results.
25+
Args:
26+
y_true: The true labels.
27+
y_pred_logits: The predicted logits.
28+
Side Effects:
29+
Appends the computed entropy values to the stored_entropy_values list.
30+
"""
31+
entropy_values = entropy(y_true, qk=y_pred_logits)
32+
return np.mean(entropy_values)

0 commit comments

Comments
 (0)