@@ -45,11 +45,13 @@ class MetricWrapper(nn.Module):
4545 {'entropy': [], 'f1': [], 'precision': []}
4646 """
4747
48- def __init__ (self , * metrics , num_classes , macro_averaging = False ):
48+ def __init__ (self , * metrics , num_classes , macro_averaging = False , ** kwargs ):
4949 super ().__init__ ()
5050 self .metrics = {}
51- self .num_classes = num_classes
52- self .macro_averaging = macro_averaging
51+ self .params = {
52+ "num_classes" : num_classes ,
53+ "macro_averaging" : macro_averaging ,
54+ } | kwargs
5355
5456 for metric in metrics :
5557 self .metrics [metric ] = self ._get_metric (metric )
@@ -73,23 +75,15 @@ def _get_metric(self, key):
7375
7476 match key .lower ():
7577 case "entropy" :
76- return EntropyPrediction (num_classes = self .num_classes )
78+ return EntropyPrediction (** self .params )
7779 case "f1" :
78- return F1Score (
79- num_classes = self .num_classes , macro_averaging = self .macro_averaging
80- )
80+ return F1Score (** self .params )
8181 case "recall" :
82- return Recall (
83- num_classes = self .num_classes , macro_averaging = self .macro_averaging
84- )
82+ return Recall (** self .params )
8583 case "precision" :
86- return Precision (
87- num_classes = self .num_classes , macro_averaging = self .macro_averaging
88- )
84+ return Precision (** self .params )
8985 case "accuracy" :
90- return Accuracy (
91- num_classes = self .num_classes , macro_averaging = self .macro_averaging
92- )
86+ return Accuracy (** self .params )
9387 case _:
9488 raise ValueError (f"Metric { key } not supported" )
9589
0 commit comments