Skip to content

Commit 370a1f8

Browse files
committed
Proposing change to metricwrapper class
1 parent 0bd207b commit 370a1f8

File tree

2 files changed

+40
-36
lines changed

2 files changed

+40
-36
lines changed

utils/load_metric.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,71 +8,63 @@
88

99
class MetricWrapper(nn.Module):
1010
"""
11-
Wrapper class for metrics, that runs multiple metrics on the same data.
12-
11+
A wrapper class for evaluating multiple metrics on the same dataset.
12+
This class allows you to compute several metrics simultaneously on given
13+
true and predicted labels. It supports a variety of common metrics and
14+
provides methods to accumulate results and reset the state.
1315
Args
1416
----
17+
num_classes : int
18+
The number of classes in the classification task.
1519
metrics : list[str]
16-
List of metrics to run on the data.
17-
20+
A list of metric names to be evaluated.
1821
Attributes
1922
----------
2023
metrics : dict
21-
Dictionary containing the metric functions.
22-
tmp_scores : dict
23-
Dictionary containing the temporary scores of the metrics.
24-
24+
A dictionary mapping metric names to their corresponding functions.
25+
num_classes : int
26+
The number of classes for the classification task.
2527
Methods
2628
-------
2729
__call__(y_true, y_pred)
28-
Call the metric functions on the true and predicted labels.
29-
accumulate()
30-
Get the average scores of the metrics.
30+
Computes the specified metrics on the provided true and predicted labels.
31+
__getmetrics__(str_prefix: str = None)
32+
Retrieves the computed metrics, optionally prefixed with a string.
3133
reset()
32-
Reset the temporary scores of the metrics.
33-
34+
Resets the state of all metric computations.
3435
Examples
3536
--------
36-
>>> from utils import MetricWrapper
37-
>>> metrics = MetricWrapper("entropy", "f1", "precision")
37+
>>> from utils import MetricWrapperProposed
38+
>>> metrics = MetricWrapperProposed(2, "entropy", "f1", "precision")
3839
>>> y_true = [0, 1, 0, 1]
3940
>>> y_pred = [0, 1, 1, 0]
4041
>>> metrics(y_true, y_pred)
41-
>>> metrics.accumulate()
42+
>>> metrics.__getmetrics__()
4243
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
4344
>>> metrics.reset()
44-
>>> metrics.accumulate()
45+
>>> metrics.__getmetrics__()
4546
{'entropy': [], 'f1': [], 'precision': []}
4647
"""
4748

4849
def __init__(self, num_classes, *metrics):
4950
super().__init__()
5051
self.metrics = {}
5152
self.num_classes = num_classes
52-
5353
for metric in metrics:
5454
self.metrics[metric] = self._get_metric(metric)
5555

56-
self.tmp_scores = copy.deepcopy(self.metrics)
57-
for key in self.tmp_scores:
58-
self.tmp_scores[key] = []
59-
6056
def _get_metric(self, key):
6157
"""
62-
Get the metric function based on the key
63-
58+
Retrieves the metric function based on the provided key.
6459
Args
6560
----
66-
key (str): metric name
67-
61+
key (str): The name of the metric.
6862
Returns
6963
-------
70-
metric (callable): metric function
64+
metric (callable): The function that computes the metric.
7165
"""
72-
7366
match key.lower():
7467
case "entropy":
75-
# Not dependent on knowing the number of classes
7668
return EntropyPrediction()
7769
case "f1":
7870
return F1Score(num_classes=self.num_classes)
@@ -87,18 +79,17 @@ def _get_metric(self, key):
8779

8880
def __call__(self, y_true, y_pred):
8981
for key in self.metrics:
90-
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
82+
self.metrics[key](y_true, y_pred)
9183

92-
def accumulate(self, str_prefix: str = None):
84+
def __getmetrics__(self, str_prefix: str = None):
9385
return_metrics = {}
9486
for key in self.metrics:
9587
if str_prefix is not None:
96-
return_metrics[str_prefix + key] = np.mean(self.tmp_scores[key])
88+
return_metrics[str_prefix + key] = self.metrics[key].__returnmetric__()
9789
else:
98-
return_metrics[key] = np.mean(self.tmp_scores[key])
99-
90+
return_metrics[key] = self.metrics[key].__returnmetric__()
10091
return return_metrics
10192

10293
def reset(self):
103-
for key in self.tmp_scores:
104-
self.tmp_scores[key] = []
94+
for key in self.metrics:
95+
self.metrics[key].reset()

utils/metrics/EntropyPred.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,21 @@ def __call__(self, y_true, y_logits):
4646
elif self.averages == "none":
4747
return entropy_values
4848

49+
self.stored_entropy_values.append(entropy_values)
50+
4951
return entropy_values
5052

53+
def return_value(self):
54+
if self.averages == "mean":
55+
self.stored_entropy_values = th.mean(self.stored_entropy_values)
56+
elif self.averages == "sum":
57+
self.stored_entropy_values = th.sum(self.stored_entropy_values)
58+
elif self.averages == "none":
59+
return self.stored_entropy_values
60+
61+
def reset(self):
62+
self.stored_entropy_values = []
63+
5164

5265
if __name__ == "__main__":
5366
import torch as th

0 commit comments

Comments
 (0)