Commit 2033d86

set cls_mtbc as nonexclusive classification
1 parent: 4adbb37 · commit: 2033d86

2 files changed: +59 −5 lines

pymic/net_run/agent_cls.py (3 additions, 2 deletions)

```diff
@@ -37,6 +37,7 @@ class ClassificationAgent(NetRunAgent):
     def __init__(self, config, stage = 'train'):
         super(ClassificationAgent, self).__init__(config, stage)
         self.transform_dict = TransformDict
+        assert(self.task_type in ["cls", "cls_nexcl"])
 
     def get_stage_dataset_from_config(self, stage):
         assert(stage in ['train', 'valid', 'test'])
@@ -106,7 +107,7 @@ def get_evaluation_score(self, outputs, labels):
             _, preds = torch.max(outputs, 1)
             consis= self.convert_tensor_type(preds == labels.data)
             score = torch.mean(consis)
-        elif(self.task_type == "cls_mtbc"): #multi-task binary classification
+        elif(self.task_type == "cls_nexcl"): #nonexclusive classification
            preds = self.convert_tensor_type(outputs > 0.5)
            consis= self.convert_tensor_type(preds == labels.data)
            score = torch.mean(consis)
```
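The scoring branch for `cls_nexcl` treats each class as an independent binary decision: outputs are thresholded at 0.5 and the score is the mean elementwise agreement with the 0/1 labels. A minimal standalone check of that branch, with toy tensors and plain `float()` casts standing in for `convert_tensor_type`:

```python
import torch

# Toy sigmoid outputs for 2 samples x 3 classes, with 0/1 labels.
outputs = torch.tensor([[0.9, 0.2, 0.7],
                        [0.4, 0.6, 0.1]])
labels  = torch.tensor([[1., 0., 1.],
                        [1., 1., 0.]])

preds  = (outputs > 0.5).float()    # threshold each class independently
consis = (preds == labels).float()  # elementwise agreement
score  = torch.mean(consis)         # 5 of 6 entries agree -> tensor(0.8333)
```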
```diff
@@ -299,7 +300,7 @@ def infer(self):
             if (self.task_type == "cls"):
                 out_prob = nn.Softmax(dim = 1)(out_digit).detach().cpu().numpy()
                 out_lab = np.argmax(out_prob, axis=1)
-            else: #self.task_type == "cls_mtbc"
+            else: #self.task_type == "cls_nexcl"
                 out_prob = nn.Sigmoid()(out_digit).detach().cpu().numpy()
                 out_lab = np.asarray(out_prob > 0.5, np.uint8)
             for i in range(len(names)):
```
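The `infer` branch is where the two task types actually diverge: softmax normalizes across classes and `argmax` picks exactly one label per sample, while the sigmoid path squashes each logit independently and the 0.5 threshold can switch on any subset of classes. A small sketch of the two post-processing paths with toy logits:

```python
import numpy as np
import torch
import torch.nn as nn

out_digit = torch.tensor([[2.0, -1.0, 0.5],
                          [0.1,  1.5, 1.2]])   # toy logits, 2 samples x 3 classes

# "cls": mutually exclusive classes -> softmax + argmax, one label per sample.
out_prob = nn.Softmax(dim=1)(out_digit).numpy()
out_lab  = np.argmax(out_prob, axis=1)           # array([0, 1])

# "cls_nexcl": independent classes -> sigmoid + threshold, a 0/1 vector per sample.
out_prob = nn.Sigmoid()(out_digit).numpy()
out_lab  = np.asarray(out_prob > 0.5, np.uint8)  # array([[1, 0, 1], [1, 1, 1]])
```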
pymic/util/evaluation_cls.py (56 additions, 3 deletions)

```diff
@@ -52,8 +52,7 @@ def get_evaluation_score(gt_label, pred_prob, metric):
         raise ValueError("undefined metric: {0:}".format(metric))
     return score
 
-def binary_evaluation(config_file):
-    config = parse_config(config_file)['evaluation']
+def binary_evaluation(config):
     metric_list = config['metric_list']
     gt_csv  = config['ground_truth_csv']
     prob_csv= config['predict_prob_csv']
@@ -71,6 +70,55 @@ def binary_evaluation(config_file):
         score_list.append(score)
         print("{0:}: {1:}".format(metric, score))
 
+    out_csv = prob_csv.replace("prob", "eval")
+    with open(out_csv, mode='w') as csv_file:
+        csv_writer = csv.writer(csv_file, delimiter=',',
+            quotechar='"',quoting=csv.QUOTE_MINIMAL)
+        csv_writer.writerow(metric_list)
+        csv_writer.writerow(score_list)
+
+def nexcl_evaluation(config):
+    """
+    evaluation for nonexclusive classification
+    """
+    metric_list = config['metric_list']
+    gt_csv    = config['ground_truth_csv']
+    prob_csv  = config['predict_prob_csv']
+    gt_items  = pd.read_csv(gt_csv)
+    prob_items= pd.read_csv(prob_csv)
+    assert(len(gt_items) == len(prob_items))
+    for i in range(len(gt_items)):
+        assert(gt_items.iloc[i, 0] == prob_items.iloc[i, 0])
+
+    cls_names = gt_items.columns[1:]
+    cls_num   = len(cls_names)
+    gt_data   = np.asarray(gt_items.iloc[:, 1:cls_num + 1])
+    prob_data = np.asarray(prob_items.iloc[:, 1:cls_num + 1])
+    score_list= []
+    for metric in metric_list:
+        print(metric)
+        score_m = []
+        for c in range(cls_num):
+            gt_data_c = gt_data[:, c:c+1]
+            prob_c = prob_data[:, c]
+            prob_c = np.asarray([1.0 - prob_c, prob_c])
+            prob_c = np.transpose(prob_c)
+            score = get_evaluation_score(gt_data_c, prob_c, metric)
+            score_m.append(score)
+            print(cls_names[c], score)
+        score_avg = np.asarray(score_m).mean()
+        print('avg', score_avg)
+        score_m.append(score_avg)
+        score_list.append(score_m)
+
+    out_csv = prob_csv.replace("prob", "eval")
+    with open(out_csv, mode='w') as csv_file:
+        csv_writer = csv.writer(csv_file, delimiter=',',
+            quotechar='"',quoting=csv.QUOTE_MINIMAL)
+        csv_writer.writerow(['metric'] + list(cls_names) + ['avg'])
+        for i in range(len(score_list)):
+            item = metric_list[i : i+1] + score_list[i]
+            csv_writer.writerow(item)
 
 def main():
     if(len(sys.argv) < 2):
```
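The new `nexcl_evaluation` assumes both CSVs share one layout: column 0 holds the sample name and every remaining column holds one class, with 0/1 values in the ground-truth file and probabilities in the prediction file. Its key move is reusing the existing binary scorer per class by expanding each probability column `p` into the two-column `[1 - p, p]` array that `get_evaluation_score` expects. A sketch with made-up class names and values:

```python
import numpy as np
import pandas as pd

# Hypothetical prediction CSV: sample names first, then one probability per class.
prob_items = pd.DataFrame({
    'image':        ['case_001.png', 'case_002.png', 'case_003.png'],
    'effusion':     [0.91, 0.12, 0.70],
    'pneumothorax': [0.33, 0.85, 0.44],
})

cls_names = prob_items.columns[1:]              # Index(['effusion', 'pneumothorax'])
prob_data = np.asarray(prob_items.iloc[:, 1:])  # shape (3, 2)

# Rebuild the (n_samples, 2) array of [P(negative), P(positive)] for class 0,
# the format the binary get_evaluation_score already consumes.
prob_c    = prob_data[:, 0]
prob_2col = np.transpose(np.asarray([1.0 - prob_c, prob_c]))
# approx. [[0.09, 0.91],
#          [0.88, 0.12],
#          [0.30, 0.70]]
```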
```diff
@@ -79,7 +127,12 @@ def main():
         exit()
     config_file = str(sys.argv[1])
     assert(os.path.isfile(config_file))
-    binary_evaluation(config_file)
+    config = parse_config(config_file)['evaluation']
+    task_type = config.get('task_type', "cls")
+    if(task_type == "cls"): # default exclusive classification
+        binary_evaluation(config)
+    else: # non exclusive classification
+        nexcl_evaluation(config)
 
 if __name__ == '__main__':
     main()
```
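With this dispatch, the evaluator is selected by an optional `task_type` key in the config's `[evaluation]` section, defaulting to `"cls"`. On the nonexclusive path the result file name is derived from the probability file (`"prob"` replaced by `"eval"`), and each row holds one metric's per-class scores plus their average. A sketch of what that output writing produces, with hypothetical metric names and scores:

```python
import csv

cls_names   = ['effusion', 'pneumothorax']  # taken from the CSV header
metric_list = ['accuracy', 'auc']           # hypothetical metric_list entries
score_list  = [[0.90, 0.85, 0.875],         # per-class scores + avg for accuracy
               [0.93, 0.88, 0.905]]         # per-class scores + avg for auc

out_csv = 'test_prob.csv'.replace("prob", "eval")  # -> 'test_eval.csv'
with open(out_csv, mode='w') as csv_file:
    writer = csv.writer(csv_file, delimiter=',',
                        quotechar='"', quoting=csv.QUOTE_MINIMAL)
    writer.writerow(['metric'] + cls_names + ['avg'])
    for metric, scores in zip(metric_list, score_list):
        writer.writerow([metric] + scores)

# Resulting test_eval.csv:
# metric,effusion,pneumothorax,avg
# accuracy,0.9,0.85,0.875
# auc,0.93,0.88,0.905
```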
