Skip to content

Commit a42bb83

Browse files
committed
Updated doc for precision
1 parent 65fedd8 commit a42bb83

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

CollaborativeCoding/metrics/precision.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55

66
class Precision(nn.Module):
7-
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives.
7+
"""Metric module for precision. Can calculate both the micro- and macro-averaged precision.
88
99
Parameters
1010
----------
1111
num_classes : int
1212
Number of classes in the dataset.
1313
micro_averaging : bool
14-
Wheter to compute the micro or macro precision (default False)
14+
Performs micro-averaging if True, otherwise macro-averaging.
1515
"""
1616

1717
def __init__(self, num_classes: int, macro_averaging: bool = False):
@@ -23,19 +23,15 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
2323
self.y_pred = []
2424

2525
def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
26-
"""Compute precision of model
26+
"""Add true and predicted values to the class-global lists.
2727
2828
Parameters
2929
----------
3030
y_true : torch.tensor
3131
True labels
32-
y_pred : torch.tensor
32+
logits : torch.tensor
3333
Predicted labels
3434
35-
Returns
36-
-------
37-
torch.tensor
38-
Precision score
3935
"""
4036
y_pred = logits.argmax(dim=-1)
4137

@@ -100,6 +96,13 @@ def _macro_avg_precision(
10096
return torch.nanmean(tp / (tp + fp))
10197

10298
def __returnmetric__(self):
99+
"""Return the micro- or macro-averaged precision.
100+
101+
Returns
102+
-------
103+
torch.tensor
104+
Micro- or macro-averaged precision
105+
"""
103106
if self.y_true == [] and self.y_pred == []:
104107
return np.nan
105108
elif self.y_true == [] or self.y_pred == []:

0 commit comments

Comments
 (0)