Skip to content

Commit 08aa876

Browse files
committed
Modified the MetricWrappers arguments being passed on
This will hopefully simplify the arguments to each metric slightly.
1 parent 2885a30 commit 08aa876

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

utils/load_metric.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ class MetricWrapper(nn.Module):
4545
{'entropy': [], 'f1': [], 'precision': []}
4646
"""
4747

48-
def __init__(self, *metrics, num_classes, macro_averaging=False):
48+
def __init__(self, *metrics, num_classes, macro_averaging=False, **kwargs):
4949
super().__init__()
5050
self.metrics = {}
51-
self.num_classes = num_classes
52-
self.macro_averaging = macro_averaging
51+
self.params = {
52+
"num_classes": num_classes,
53+
"macro_averaging": macro_averaging,
54+
} | kwargs
5355

5456
for metric in metrics:
5557
self.metrics[metric] = self._get_metric(metric)
@@ -73,23 +75,15 @@ def _get_metric(self, key):
7375

7476
match key.lower():
7577
case "entropy":
76-
return EntropyPrediction(num_classes=self.num_classes)
78+
return EntropyPrediction(**self.params)
7779
case "f1":
78-
return F1Score(
79-
num_classes=self.num_classes, macro_averaging=self.macro_averaging
80-
)
80+
return F1Score(**self.params)
8181
case "recall":
82-
return Recall(
83-
num_classes=self.num_classes, macro_averaging=self.macro_averaging
84-
)
82+
return Recall(**self.params)
8583
case "precision":
86-
return Precision(
87-
num_classes=self.num_classes, macro_averaging=self.macro_averaging
88-
)
84+
return Precision(**self.params)
8985
case "accuracy":
90-
return Accuracy(
91-
num_classes=self.num_classes, macro_averaging=self.macro_averaging
92-
)
86+
return Accuracy(**self.params)
9387
case _:
9488
raise ValueError(f"Metric {key} not supported")
9589

0 commit comments

Comments
 (0)