forked from xychen2022/VersatileSegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_final_performance.py
More file actions
73 lines (50 loc) · 2.06 KB
/
compute_final_performance.py
File metadata and controls
73 lines (50 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import SimpleITK as sitk
import pandas as pd
import numpy as np
import os
from ID2AbdominalOrgan import ID2Organ
# --- experiment configuration -------------------------------------------
# Root of the dataset; validation labels live under <data_path>/labelsVa.
data_path = '/gpfs/fs001/cbica/home/chexiao/MultiOrgan/awesome/MM'
total_classes = 17

# CSV header: "ID" followed by organ names in ascending label-id order.
# The background class is never scored, so drop it from the columns.
col_names = ["ID"] + [ID2Organ[organ_id] for organ_id in sorted(ID2Organ.keys())]
col_names.remove("background")
print("columns: ", col_names, ", num_columns = ", len(col_names))

# One ground-truth label file per validation subject, in deterministic order.
val_subjects = sorted(os.listdir(data_path + "/labelsVa"))
def one_hot_encoding(label, numClasses):
    """One-hot encode an integer label volume.

    Returns a boolean array of shape (numClasses, *label.shape) where
    channel c is the mask `label == c`.
    """
    return np.stack([label == cls for cls in range(numClasses)], axis=0)
# --- per-subject evaluation ----------------------------------------------
# For each validation subject, compare the predicted segmentation against
# the ground truth and record a smoothed Dice score per organ class.
data = []
for subject in val_subjects:
    subject_id = subject.split('.')[0]
    print("Subject: ", subject_id)

    groundtruth = sitk.ReadImage(data_path + '/labelsVa/' + subject_id + ".nii.gz")
    groundtruth = sitk.GetArrayFromImage(groundtruth)
    # Map label id -> voxel count; used to skip organs that are absent
    # (or negligibly small) in the ground truth.
    classes_that_exist, counts = np.unique(groundtruth.astype(np.int32), return_counts=True)
    voxel_count = dict(zip(classes_that_exist.tolist(), counts.tolist()))

    gt = one_hot_encoding(groundtruth, numClasses=total_classes)
    print("gt: ", gt.shape)

    prediction = sitk.ReadImage('./best_results/' + 'pred_' + subject_id + '.nii.gz')
    prediction = sitk.GetArrayFromImage(prediction)
    pred = one_hot_encoding(prediction, numClasses=total_classes)
    print("pred: ", pred.shape)

    dice = [subject_id]
    # Class 0 (background) is skipped. Loop variable renamed from `idx`,
    # which previously shadowed the outer subject index.
    for cls in range(1, total_classes):
        if voxel_count.get(cls, 0) > 50:
            pred_c = pred[cls]
            gt_c = gt[cls]
            intersection = np.sum(pred_c * gt_c)
            # Smoothed Dice: the +1.0 terms avoid division by zero when
            # both masks are empty. (`intersection` is already a scalar;
            # the original wrapped it in a redundant np.sum.)
            dsc = (2. * intersection + 1.0) / (np.sum(pred_c) + np.sum(gt_c) + 1.0)
            dice.append(dsc)
        else:
            # Organ absent/too small in ground truth: score undefined.
            dice.append(None)

    data.append(dice)
    print(dice)
    print('\n')

# Build the DataFrame straight from the list of rows. Converting through
# np.array() would coerce the float Dice scores to strings whenever no
# None is present (dtype '<U...'), corrupting the CSV's numeric columns.
df = pd.DataFrame(data, columns=col_names)
print(df.shape)
df.to_csv('overall_performance.csv', index=False)