Skip to content

Commit cba9b80

Browse files
committed
Simplify how metrics are parsed
Use a switch statement to handle the different cases.
1 parent 1a664b7 commit cba9b80

File tree

2 files changed

+35
-34
lines changed

2 files changed

+35
-34
lines changed

main.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,8 @@ def main():
3939
parser.add_argument('--dataset', type=str, default='svhn',
4040
choices=['svhn'], help='Which dataset to train the model on.')
4141

42-
parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation')
43-
parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation')
44-
parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation')
45-
parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation')
46-
parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation')
47-
42+
parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which metric to use for evaluation')
43+
4844
#Training specific values
4945
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
5046
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
@@ -61,13 +57,7 @@ def main():
6157
model = load_model()
6258
model.to(device)
6359

64-
metrics = MetricWrapper(
65-
EntropyPred = args.EntropyPrediction,
66-
F1Score = args.F1Score,
67-
Recall = args.Recall,
68-
Precision = args.Precision,
69-
Accuracy = args.Accuracy
70-
)
60+
metrics = MetricWrapper(*args.metric)
7161

7262
#Dataset
7363
traindata = load_data(args.dataset)
@@ -126,4 +116,4 @@ def main():
126116

127117

128118
if __name__ == '__main__':
129-
main()
119+
main()

utils/load_metric.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,45 @@
55

66

77
class MetricWrapper(nn.Module):
8-
def __init__(self,
9-
EntropyPred:bool = True,
10-
F1Score:bool = True,
11-
Recall:bool = True,
12-
Precision:bool = True,
13-
Accuracy:bool = True):
8+
def __init__(self, *metrics):
149
super().__init__()
1510
self.metrics = {}
16-
17-
if EntropyPred:
18-
self.metrics['Entropy of Predictions'] = EntropyPrediction()
1911

20-
if F1Score:
21-
self.metrics['F1 Score'] = None
22-
23-
if Recall:
24-
self.metrics['Recall'] = None
12+
for metric in metrics:
13+
self.metrics[metric] = self._get_metric(metric)
2514

26-
if Precision:
27-
self.metrics['Precision'] = None
28-
29-
if Accuracy:
30-
self.metrics['Accuracy'] = None
31-
3215
self.tmp_scores = copy.deepcopy(self.metrics)
3316
for key in self.tmp_scores:
3417
self.tmp_scores[key] = []
3518

19+
20+
def _get_metric(self, key):
21+
"""
22+
Get the metric function based on the key
23+
24+
Args
25+
----
26+
key (str): metric name
27+
28+
Returns
29+
-------
30+
metric (callable): metric function
31+
"""
32+
33+
match key.lower():
34+
case 'entropy':
35+
return EntropyPrediction()
36+
case 'f1':
37+
raise NotImplementedError("F1 score not implemented yet")
38+
case 'recall':
39+
raise NotImplementedError("Recall score not implemented yet")
40+
case 'precision':
41+
raise NotImplementedError("Precision score not implemented yet")
42+
case 'accuracy':
43+
raise NotImplementedError("Accuracy score not implemented yet")
44+
case _:
45+
raise ValueError(f"Metric {key} not supported")
46+
3647
def __call__(self, y_true, y_pred):
3748
for key in self.metrics:
3849
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))

0 commit comments

Comments
 (0)