
Commit 8d04d64

committed
added cristae eval script
1 parent 71a0343 commit 8d04d64

File tree

1 file changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import argparse
import os
from glob import glob

import pandas as pd
from elf.io import open_file
from elf.evaluation import matching, symmetric_best_dice_score


def evaluate(labels, vesicles):
    assert labels.shape == vesicles.shape
    stats = matching(vesicles, labels)
    sbd = symmetric_best_dice_score(vesicles, labels)
    return [stats["f1"], stats["precision"], stats["recall"], sbd]

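# Note: `matching` computes instance-level matching statistics (precision,
# recall, f1) and `symmetric_best_dice_score` a segment-overlap score, both
# from elf.evaluation. An IoU threshold of 0.5 for counting a match is an
# assumption about the library default; check the installed elf version.

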
def summarize_eval(results):
    summary = (
        results[["dataset", "f1-score", "precision", "recall", "SBD score"]]
        .groupby("dataset")
        .mean()
        .reset_index("dataset")
    )
    total = results[["f1-score", "precision", "recall", "SBD score"]].mean().values.tolist()
    # Append the overall mean as a new "all" row; iloc[-1] would overwrite the last dataset row.
    summary.loc[len(summary)] = ["all"] + total
    table = summary.to_markdown(index=False)  # requires the 'tabulate' package
    print(table)

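# The printed summary is a markdown table with per-dataset means plus an
# extra "all" row averaged over every tomogram, e.g. (illustrative values only):
#
# | dataset | f1-score | precision | recall | SBD score |
# |:--------|---------:|----------:|-------:|----------:|
# | ds_a    |     0.81 |      0.84 |   0.78 |      0.72 |
# | all     |     0.81 |      0.84 |   0.78 |      0.72 |

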
def evaluate_file(labels_path, seg_path, model_name, segment_key, anno_key, mask_key, output_folder):
    # NOTE: mask_key is accepted for API symmetry but is not used yet.
    print(f"Evaluating labels\n{labels_path}\nand segmentation\n{seg_path}")
    labels = open_file(labels_path)
    seg = open_file(seg_path)
    if segment_key is not None:
        seg = seg[segment_key][:]
    if anno_key is not None:
        labels = labels[anno_key][:]
    if labels is None or seg is None:
        print("Could not find label file for", seg_path)
        print("Skipping...")
        return

    # evaluate the match of ground truth and vesicles
    scores = evaluate(labels, seg)

    # store results
    result_folder = output_folder
    os.makedirs(result_folder, exist_ok=True)
    result_path = os.path.join(result_folder, f"evaluation_{model_name}.csv")
    print("Evaluation results are saved to:", result_path)

    if os.path.exists(result_path):
        results = pd.read_csv(result_path)
    else:
        results = None
    ds_name = os.path.basename(os.path.dirname(labels_path))
    tomo = os.path.basename(labels_path)
    res = pd.DataFrame(
        [[ds_name, tomo] + scores], columns=["dataset", "tomogram", "f1-score", "precision", "recall", "SBD score"]
    )
    if results is None:
        results = res
    else:
        results = pd.concat([results, res])
    results.to_csv(result_path, index=False)

    # print results
    summarize_eval(results)


def evaluate_folder(labels_path, segmentation_path, model_name, segment_key,
                    anno_key, mask_key, output_folder, ext=".tif"):
    print(f"Evaluating folder {segmentation_path}")
    print(f"Using labels stored in {labels_path}")

    label_paths = get_file_paths(labels_path, ext=ext)
    seg_paths = get_file_paths(segmentation_path, ext=ext)
    # get_file_paths returns an empty list (not None) when nothing matches.
    if not label_paths or not seg_paths:
        print("Could not find label file or segmentation file")
        return

    for seg_path in seg_paths:
        label_path = find_label_file(seg_path, label_paths)
        if label_path is not None:
            evaluate_file(label_path, seg_path, model_name, segment_key, anno_key, mask_key, output_folder)
        else:
            print("Could not find label file for", seg_path)
            print("Skipping...")


def get_file_paths(path, ext=".h5", reverse=False):
    # A single file is returned as-is; a directory is searched recursively.
    if path.endswith(ext):
        return [path]
    else:
        paths = sorted(glob(os.path.join(path, "**", f"*{ext}"), recursive=True), reverse=reverse)
        return paths


def find_label_file(given_path: str, label_paths: list) -> str:
    """Find the corresponding label file for a given raw file.

    Args:
        given_path (str): The raw file path for which to find the label file.
        label_paths (list): A list of label file paths.

    Returns:
        str: The path to the matching label file, or None if no match is found.
    """
    raw_base = os.path.splitext(os.path.basename(given_path))[0]  # Remove the extension.

    for label_path in label_paths:
        label_base = os.path.splitext(os.path.basename(label_path))[0]  # Remove the extension.
        if raw_base in label_base:  # The raw name must be contained in the label name.
            return label_path

    return None  # No match found.

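# Matching is by substring on the basenames: e.g. a raw file "tomo_01.h5"
# would pair with a label file "tomo_01_cristae.h5" (hypothetical names),
# since "tomo_01" is contained in "tomo_01_cristae".

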
def main():
    parser = argparse.ArgumentParser()
    # Defaults only take effect when the arguments are not marked as required.
    parser.add_argument("-sp", "--segmentation_path",
                        default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/cristae_test_segmentations/")
    parser.add_argument("-gp", "--groundtruth_path",
                        default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/cristae_test_segmentations/")
    parser.add_argument("-n", "--model_name", required=True)
    parser.add_argument("-sk", "--segmentation_key", default="labels/new_cristae_seg")
    parser.add_argument("-gk", "--groundtruth_key", default="labels/cristae")
    parser.add_argument("-m", "--mask_key", default=None)
    parser.add_argument(
        "-o", "--output_folder",
        default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/cristae_test_segmentations/eval"
    )
    args = parser.parse_args()

    if os.path.isdir(args.segmentation_path):
        evaluate_folder(args.groundtruth_path, args.segmentation_path, args.model_name, args.segmentation_key,
                        args.groundtruth_key, args.mask_key, args.output_folder)
    else:
        evaluate_file(args.groundtruth_path, args.segmentation_path, args.model_name, args.segmentation_key,
                      args.groundtruth_key, args.mask_key, args.output_folder)


if __name__ == "__main__":
    main()
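
# Example invocation (hypothetical script and data paths; -sk/-gk are the
# in-file dataset keys for the segmentation and the ground-truth labels):
#   python cristae_eval.py -n my_model \
#       -sp /path/to/segmentations -gp /path/to/groundtruth \
#       -sk labels/new_cristae_seg -gk labels/cristae \
#       -o /path/to/eval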
