Skip to content

Commit cdd5a4f

Browse files
committed
Merge branch 'johan/devbranch' into johan/micromacro
2 parents a35e6ea + 4c3dc32 commit cdd5a4f

File tree

3 files changed

+77
-16
lines changed

3 files changed

+77
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
"sphinx-rtd-theme>=3.0.2",
2222
"torch>=2.6.0",
2323
"torchvision>=0.21.0",
24+
"tqdm>=4.67.1",
2425
]
2526
[tool.isort]
2627
profile = "black"

utils/metrics/precision.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,95 @@ class Precision(nn.Module):
1313
----------
1414
num_classes : int
1515
Number of classes in the dataset.
16-
use_mean : bool
17-
Whether to calculate precision as a mean of precisions or as a brute function of true positives and false positives.
16+
micro_averaging : bool
17+
Wheter to compute the micro or macro precision (default False)
1818
"""
1919

20-
def __init__(self, num_classes: int, use_mean: bool = True):
20+
def __init__(self, num_classes: int, micro_averaging: bool = False):
2121
super().__init__()
2222

2323
self.num_classes = num_classes
24-
self.use_mean = use_mean
24+
self._micro_averaging = micro_averaging
2525

2626
def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
27-
"""Calculates the precision score given number of classes and the true and predicted labels.
27+
"""Compute precision of model
2828
2929
Parameters
3030
----------
3131
y_true : torch.tensor
32-
true labels
32+
True labels
3333
y_pred : torch.tensor
34-
predicted labels
34+
Predicted labels
3535
3636
Returns
3737
-------
3838
torch.tensor
39-
precision score
39+
Precision score
40+
"""
41+
return (
42+
self._micro_avg_precision(y_true, y_pred)
43+
if self.micro_averaging
44+
else self._macro_avg_precision(y_true, y_pred)
45+
)
46+
47+
def _micro_avg_precision(
48+
self, y_true: torch.tensor, y_pred: torch.tensor
49+
) -> torch.tensor:
50+
"""Compute micro-average precision by first calculating true/false positive across all classes and then find the precision.
51+
52+
Parameters
53+
----------
54+
y_true : torch.tensor
55+
True labels
56+
y_pred : torch.tensor
57+
Predicted labels
58+
59+
Returns
60+
-------
61+
torch.tensor
62+
Micro-averaged precision
4063
"""
41-
# One-hot encode the target tensor
4264
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
4365
1, y_true.unsqueeze(1), 1
4466
)
4567
pred_oh = torch.zeros(y_pred.size(0), self.num_classes).scatter_(
4668
1, y_pred.unsqueeze(1), 1
4769
)
70+
tp = torch.sum(true_oh * pred_oh)
71+
fp = torch.sum(~true_oh[pred_oh.bool()].bool())
4872

49-
if self.use_mean:
50-
tp = torch.sum(true_oh * pred_oh, 0)
51-
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
73+
return torch.nanmean(tp / (tp + fp))
74+
75+
def _macro_avg_precision(
76+
self, y_true: torch.tensor, y_pred: torch.tensor
77+
) -> torch.tensor:
78+
"""Compute macro-average precision by finding true/false positives of each class separately then averaging across all classes.
5279
53-
else:
54-
tp = torch.sum(true_oh * pred_oh)
55-
fp = torch.sum(~true_oh[pred_oh.bool()].bool())
80+
Parameters
81+
----------
82+
y_true : torch.tensor
83+
True labels
84+
y_pred : torch.tensor
85+
Predicted labels
86+
87+
Returns
88+
-------
89+
torch.tensor
90+
Macro-averaged precision
91+
"""
92+
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
93+
1, y_true.unsqueeze(1), 1
94+
)
95+
pred_oh = torch.zeros(y_pred.size(0), self.num_classes).scatter_(
96+
1, y_pred.unsqueeze(1), 1
97+
)
98+
tp = torch.sum(true_oh * pred_oh, 0)
99+
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
56100

57101
return torch.nanmean(tp / (tp + fp))
58102

59103

60104
if __name__ == "__main__":
61-
pass
105+
print(
106+
"Congratulations, you succesfully ran the Precision metric class. You should be proud of this marvelous achievement!"
107+
)

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)