@@ -52,8 +52,7 @@ def get_evaluation_score(gt_label, pred_prob, metric):
         raise ValueError("undefined metric: {0:}".format(metric))
     return score
 
-def binary_evaluation(config_file):
-    config = parse_config(config_file)['evaluation']
+def binary_evaluation(config):
     metric_list = config['metric_list']
     gt_csv = config['ground_truth_csv']
     prob_csv = config['predict_prob_csv']
@@ -71,6 +70,55 @@ def binary_evaluation(config_file):
         score_list.append(score)
         print("{0:}: {1:}".format(metric, score))
 
+    out_csv = prob_csv.replace("prob", "eval")
+    with open(out_csv, mode='w') as csv_file:
+        csv_writer = csv.writer(csv_file, delimiter=',',
+                                quotechar='"', quoting=csv.QUOTE_MINIMAL)
+        csv_writer.writerow(metric_list)
+        csv_writer.writerow(score_list)
+
+def nexcl_evaluation(config):
+    """
+    evaluation for nonexclusive classification
+    """
+    metric_list = config['metric_list']
+    gt_csv = config['ground_truth_csv']
+    prob_csv = config['predict_prob_csv']
+    gt_items = pd.read_csv(gt_csv)
+    prob_items = pd.read_csv(prob_csv)
+    assert(len(gt_items) == len(prob_items))
+    for i in range(len(gt_items)):
+        assert(gt_items.iloc[i, 0] == prob_items.iloc[i, 0])
+
+    cls_names = gt_items.columns[1:]
+    cls_num = len(cls_names)
+    gt_data = np.asarray(gt_items.iloc[:, 1:cls_num + 1])
+    prob_data = np.asarray(prob_items.iloc[:, 1:cls_num + 1])
+    score_list = []
+    for metric in metric_list:
+        print(metric)
+        score_m = []
+        for c in range(cls_num):
+            gt_data_c = gt_data[:, c:c + 1]
+            prob_c = prob_data[:, c]
+            prob_c = np.asarray([1.0 - prob_c, prob_c])
+            prob_c = np.transpose(prob_c)
+            score = get_evaluation_score(gt_data_c, prob_c, metric)
+            score_m.append(score)
+            print(cls_names[c], score)
+        score_avg = np.asarray(score_m).mean()
+        print('avg', score_avg)
+        score_m.append(score_avg)
+        score_list.append(score_m)
+
+    out_csv = prob_csv.replace("prob", "eval")
+    with open(out_csv, mode='w') as csv_file:
+        csv_writer = csv.writer(csv_file, delimiter=',',
+                                quotechar='"', quoting=csv.QUOTE_MINIMAL)
+        csv_writer.writerow(['metric'] + list(cls_names) + ['avg'])
+        for i in range(len(score_list)):
+            item = metric_list[i:i + 1] + score_list[i]
+            csv_writer.writerow(item)
 
 def main():
     if(len(sys.argv) < 2):
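
Side note on the per-class scoring in nexcl_evaluation above: it reuses the binary scoring path by expanding each class's 1-D probability column into a two-column array [P(negative), P(positive)] before calling get_evaluation_score. A minimal standalone sketch of that reshaping, with made-up values purely for illustration:

    import numpy as np

    # predicted P(class c) for 3 samples
    prob_c = np.array([0.1, 0.8, 0.4])
    # stack the complementary probability and transpose to shape (3, 2),
    # matching what nexcl_evaluation passes to get_evaluation_score
    prob_2col = np.transpose(np.asarray([1.0 - prob_c, prob_c]))
    print(prob_2col)
    # [[0.9 0.1]
    #  [0.2 0.8]
    #  [0.6 0.4]]
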
@@ -79,7 +127,12 @@ def main():
         exit()
     config_file = str(sys.argv[1])
     assert(os.path.isfile(config_file))
-    binary_evaluation(config_file)
+    config = parse_config(config_file)['evaluation']
+    task_type = config.get('task_type', "cls")
+    if(task_type == "cls"):  # default exclusive classification
+        binary_evaluation(config)
+    else:  # non exclusive classification
+        nexcl_evaluation(config)
 
 if __name__ == '__main__':
     main()
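
With this change, main() parses the configuration file once, takes its 'evaluation' section, and dispatches on the optional task_type key (default "cls"). A minimal sketch of calling the two evaluation functions directly from Python; the config dict below is hand-built for illustration, the CSV paths are placeholders, and the metric name is hypothetical (valid values are whatever get_evaluation_score accepts):

    # build the 'evaluation' section by hand instead of loading it via parse_config
    config = {
        'task_type':        'nexcl',        # anything other than "cls" selects nexcl_evaluation
        'metric_list':      ['accuracy'],   # placeholder metric name
        'ground_truth_csv': 'test/gt.csv',          # placeholder path
        'predict_prob_csv': 'test/pred_prob.csv',   # placeholder path
    }
    if config.get('task_type', 'cls') == 'cls':
        binary_evaluation(config)
    else:
        nexcl_evaluation(config)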