Skip to content

Commit f2a4b90

Browse files
committed
updated precision metric to comply with seilmast's changes in #77
1 parent d9be199 commit f2a4b90

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

CollaborativeCoding/metrics/precision.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
1818

1919
self.num_classes = num_classes
2020
self.macro_averaging = macro_averaging
21+
self.y_true = []
22+
self.y_pred = []
2123

2224
def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
2325
"""Compute precision of model
@@ -35,17 +37,11 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
3537
Precision score
3638
"""
3739
y_pred = logits.argmax(dim=-1)
38-
return (
39-
self._macro_avg_precision(y_true, y_pred)
40-
if self.macro_averaging
41-
else self._micro_avg_precision(y_true, y_pred)
42-
)
40+
41+
# Append to the class-global values
42+
self.y_true.append(y_true)
43+
self.y_pred.append(y_pred)
4344

44-
def accumulate(self):
45-
pass # TODO fill
46-
47-
def reset(self):
48-
pass # TODO fill
4945

5046
def _micro_avg_precision(
5147
self, y_true: torch.tensor, y_pred: torch.tensor
@@ -64,7 +60,6 @@ def _micro_avg_precision(
6460
torch.tensor
6561
Micro-averaged precision
6662
"""
67-
print(y_true.shape)
6863
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
6964
1, y_true.unsqueeze(1), 1
7065
)
@@ -103,6 +98,27 @@ def _macro_avg_precision(
10398
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
10499

105100
return torch.nanmean(tp / (tp + fp))
101+
102+
def __returnmetric__(self):
103+
if self.y_true == [] and self.y_pred == []:
104+
return []
105+
elif self.y_true == [] or self.y_pred == []:
106+
raise ValueError("y_true or y_pred is empty.")
107+
self.y_true = torch.cat(self.y_true)
108+
self.y_pred = torch.cat(self.y_pred)
109+
110+
return self._macro_avg_precision(self.y_true, self.y_pred) if self.macro_averaging else self._micro_avg_precision(self.y_true, self.y_pred)
111+
112+
def __reset__(self):
113+
"""Resets the class-global lists of true and predicted values to empty lists.
114+
115+
Returns
116+
-------
117+
None
118+
Returns None
119+
"""
120+
self.y_true = self.y_pred = []
121+
return None
106122

107123

108124
if __name__ == "__main__":

0 commit comments

Comments
 (0)