88
99class MetricWrapper (nn .Module ):
1010 """
11- Wrapper class for metrics, that runs multiple metrics on the same data.
12-
11+ A wrapper class for evaluating multiple metrics on the same dataset.
12+ This class allows you to compute several metrics simultaneously on given
13+ true and predicted labels. It supports a variety of common metrics and
14+ provides methods to accumulate results and reset the state.
1315 Args
1416 ----
17+ num_classes : int
18+ The number of classes in the classification task.
1519 metrics : list[str]
16- List of metrics to run on the data.
17-
20+ A list of metric names to be evaluated.
1821 Attributes
1922 ----------
2023 metrics : dict
21- Dictionary containing the metric functions.
22- tmp_scores : dict
23- Dictionary containing the temporary scores of the metrics.
24-
24+ A dictionary mapping metric names to their corresponding functions.
25+ num_classes : int
26+ The number of classes for the classification task.
2527 Methods
2628 -------
2729 __call__(y_true, y_pred)
28- Call the metric functions on the true and predicted labels.
29- accumulate( )
30- Get the average scores of the metrics .
30+ Computes the specified metrics on the provided true and predicted labels.
31+ __getmetrics__(str_prefix: str = None )
32+ Retrieves the computed metrics, optionally prefixed with a string .
3133 reset()
32- Reset the temporary scores of the metrics.
33-
34+ Resets the state of all metric computations.
3435 Examples
3536 --------
36- >>> from utils import MetricWrapper
37- >>> metrics = MetricWrapper( "entropy", "f1", "precision")
37+ >>> from utils import MetricWrapperProposed
38+ >>> metrics = MetricWrapperProposed(2, "entropy", "f1", "precision")
3839 >>> y_true = [0, 1, 0, 1]
3940 >>> y_pred = [0, 1, 1, 0]
4041 >>> metrics(y_true, y_pred)
41- >>> metrics.accumulate ()
42+ >>> metrics.__getmetrics__ ()
4243 {'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
4344 >>> metrics.reset()
44- >>> metrics.accumulate ()
45+ >>> metrics.__getmetrics__ ()
4546 {'entropy': [], 'f1': [], 'precision': []}
4647 """
4748
4849 def __init__ (self , num_classes , * metrics ):
4950 super ().__init__ ()
5051 self .metrics = {}
5152 self .num_classes = num_classes
52-
5353 for metric in metrics :
5454 self .metrics [metric ] = self ._get_metric (metric )
5555
56- self .tmp_scores = copy .deepcopy (self .metrics )
57- for key in self .tmp_scores :
58- self .tmp_scores [key ] = []
59-
6056 def _get_metric (self , key ):
6157 """
62- Get the metric function based on the key
63-
58+ Retrieves the metric function based on the provided key.
6459 Args
6560 ----
66- key (str): metric name
67-
61+ key (str): The name of the metric.
6862 Returns
6963 -------
70- metric (callable): metric function
64+ metric (callable): The function that computes the metric.
7165 """
72-
7366 match key .lower ():
7467 case "entropy" :
75- # Not dependent on knowing the number of classes
7668 return EntropyPrediction ()
7769 case "f1" :
7870 return F1Score (num_classes = self .num_classes )
@@ -87,18 +79,17 @@ def _get_metric(self, key):
8779
8880 def __call__ (self , y_true , y_pred ):
8981 for key in self .metrics :
90- self .tmp_scores [ key ]. append ( self . metrics [key ](y_true , y_pred ) )
82+ self .metrics [key ](y_true , y_pred )
9183
92- def accumulate (self , str_prefix : str = None ):
84+ def __getmetrics__ (self , str_prefix : str = None ):
9385 return_metrics = {}
9486 for key in self .metrics :
9587 if str_prefix is not None :
96- return_metrics [str_prefix + key ] = np . mean ( self .tmp_scores [key ])
88+ return_metrics [str_prefix + key ] = self .metrics [key ]. __returnmetric__ ( )
9789 else :
98- return_metrics [key ] = np .mean (self .tmp_scores [key ])
99-
90+ return_metrics [key ] = self .metrics [key ].__returnmetric__ ()
10091 return return_metrics
10192
10293 def reset (self ):
103- for key in self .tmp_scores :
104- self .tmp_scores [key ] = []
94+ for key in self .metrics :
95+ self .metrics [key ]. reset ()
0 commit comments