Skip to content

Commit ba90d89

Browse files
committed
Updated precision metric with macro_averaging as argument
1 parent 0c16ba1 commit ba90d89

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

utils/metrics/precision.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ class Precision(nn.Module):
1313
Wheter to compute the micro or macro precision (default False)
1414
"""
1515

16-
def __init__(self, num_classes: int, micro_averaging: bool = False):
16+
def __init__(self, num_classes: int, macro_averaging: bool = False):
1717
super().__init__()
1818

1919
self.num_classes = num_classes
20-
self.micro_averaging = micro_averaging
20+
self.macro_averaging = macro_averaging
2121

2222
def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
2323
"""Compute precision of model
@@ -35,9 +35,9 @@ def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
3535
Precision score
3636
"""
3737
return (
38-
self._micro_avg_precision(y_true, y_pred)
39-
if self.micro_averaging
40-
else self._macro_avg_precision(y_true, y_pred)
38+
self._macro_avg_precision(y_true, y_pred)
39+
if self.macro_averaging
40+
else self._micro_avg_precision(y_true, y_pred)
4141
)
4242

4343
def _micro_avg_precision(

0 commit comments

Comments
 (0)