-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathft_metrics.py
More file actions
executable file
·53 lines (45 loc) · 2.03 KB
/
ft_metrics.py
File metadata and controls
executable file
·53 lines (45 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python
#!/usr/local/bin/python3
# @author cpuhrsch https://github.com/cpuhrsch
# @author Loreto Parisi loreto@musixmatch.com
# On 2018/09/08 modified by Yuen-Hsien Tseng from:
# https://gist.github.com/loretoparisi/41b918add11893d761d0ec12a3a4e1aa
import argparse
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import classification_report
def parse_labels(path):
    """Read a whitespace-separated label file into a numpy string array.

    Every token has its first 9 characters stripped — presumably the
    fastText '__label__' prefix (TODO confirm against the input files).
    """
    with open(path, 'r') as fh:
        tokens = fh.read().split()
    return np.array([token[9:] for token in tokens])
def tcfunc(x, n=4):
    """Truncate a number to n decimal digits (toward zero).

    Non-numeric values (e.g. the None support that sklearn returns for
    micro averaging) are passed through unchanged.
    """
    scale = 10 ** n  # replaces the original string-built int('1' + '0'*n)
    # https://stackoverflow.com/questions/4541155/check-if-a-number-is-int-or-float
    if isinstance(x, (int, float)):
        return int(x * scale) / scale
    return x
if __name__ == "__main__":
    # CLI entry point: compare gold labels against predictions and print
    # micro- and macro-averaged precision / recall / F1 / support as TSV.
    cli = argparse.ArgumentParser(description='Display confusion matrix.')
    cli.add_argument('test', help='Path to test labels')
    cli.add_argument('predict', help='Path to predictions')
    opts = cli.parse_args()

    gold = parse_labels(opts.test)
    guessed = parse_labels(opts.predict)
    eq = gold == guessed  # element-wise match mask; only the commented-out accuracy print uses it
    # print("Accuracy: " + str(eq.sum() / len(test_labels)))
    # print(confusion_matrix(test_labels, pred_labels))

    print("\tPrecision\tRecall\tF1\tSupport")
    micro = tuple(map(tcfunc,
                      precision_recall_fscore_support(gold, guessed, average='micro')))
    print("Micro\t{}\t{}\t{}\t{}".format(*micro))
    macro = tuple(map(tcfunc,
                      precision_recall_fscore_support(gold, guessed, average='macro')))
    print("Macro\t{}\t{}\t{}\t{}".format(*macro))
    exit()
'''
try:
    print(classification_report(test_labels, pred_labels, digits=4))
except ValueError:
    print('May be some category has no predicted samples')
'''