Skip to content

Commit ff11f88

Browse files
Add more evaluation scripts
1 parent 1f392ef commit ff11f88

File tree

7 files changed

+410
-12
lines changed

7 files changed

+410
-12
lines changed

scripts/cooper/full_reconstruction/qualitative_evaluation.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,101 @@
55
import pandas as pd
66
import napari
77

8+
from skimage.measure import label
9+
810
from tqdm import tqdm
911

1012
# Path of the spreadsheet with the manual qualitative STEM evaluation;
# loaded once at module level so all eval functions share the same table.
_VAL_TABLE_PATH = "/home/pape/Desktop/sfb1286/mboc_synapse/qualitative-stem-eval.xlsx"
val_table = pd.read_excel(_VAL_TABLE_PATH)
1214

1315

16+
def _get_n_azs(path):
    """Count the active zones in a predicted segmentation file.

    Loads the active-zone prediction (downsampled by 2 along each axis to
    save memory), labels its connected components, and counts the
    components above a size threshold.

    Returns the count twice, as ``(n_found, n_total)``: for tomograms where
    all AZs were found, the prediction itself defines both values.
    """
    # Downsample by a factor of 2 along every axis.
    access = np.s_[::2, ::2, ::2]
    with h5py.File(path, "r") as f:
        az = f["labels/active_zone"][access]
    az = label(az)
    # Component sizes without the background label (id 0); the component
    # ids themselves are not needed.
    _, sizes = np.unique(az, return_counts=True)
    sizes = sizes[1:]
    # NOTE(review): 10000 voxels (at half resolution) is assumed to be the
    # minimum size for a real AZ — confirm against the annotation protocol.
    n_azs = np.sum(sizes > 10000)
    return n_azs, n_azs
26+
1427
def eval_az():
    """Evaluate how many active zones were correctly identified.

    Reads the manual counts from the module-level ``val_table``. For rows
    marked "all", the prediction file itself is loaded and its
    size-filtered components are counted as both found and total.
    """
    azs_found = []
    azs_total = []

    # For the "all" tomograms load the prediction, measure the number of
    # components, size filter, and count these as found and as total.
    for _, row in tqdm(val_table.iterrows(), total=len(val_table)):
        az_found = row["AZ Found"]
        if az_found == "all":
            path = os.path.join("04_full_reconstruction", row.dataset, row.tomogram)
            assert os.path.exists(path)
            az_found, az_total = _get_n_azs(path)
        else:
            az_total = row["AZ Total"]

        azs_found.append(az_found)
        azs_total.append(az_total)

    n_found = np.sum(azs_found)
    n_azs = np.sum(azs_total)

    print("AZ Evaluation:")
    # Bug fix: the ratio has to be multiplied by 100 to be a percentage.
    print("Number of correctly identified AZs:", n_found, "/", n_azs, f"({100.0 * n_found / n_azs}%)")
50+
51+
52+
# Measure in how many pieces each compartment was split.
def eval_compartments():
    """Report mean, std and max number of pieces per compartment.

    A compartment cell in ``val_table`` is either a comma-separated list
    (the compartment was split into that many pieces), NaN (compartment
    not present), or a single number (one piece).
    """
    pieces_per_compartment = []
    for _, row in val_table.iterrows():
        for comp in [
            "Compartment 1",
            "Compartment 2",
            "Compartment 3",
            "Compartment 4",
        ]:
            n_pieces = row[comp]
            if isinstance(n_pieces, str):
                # Comma-separated piece list: count the entries.
                n_pieces = len(n_pieces.split(","))
            elif np.isnan(n_pieces):
                # Compartment not present in this tomogram.
                continue
            else:
                assert isinstance(n_pieces, (float, int))
                n_pieces = 1
            pieces_per_compartment.append(n_pieces)

    avg = np.mean(pieces_per_compartment)
    std = np.std(pieces_per_compartment)
    max_ = np.max(pieces_per_compartment)
    print("Compartment Evaluation:")
    # Bug fix: corrected typo in the printed message ("Avergage").
    print("Average pieces per compartment:", avg, "+-", std)
    print("Max pieces per compartment:", max_)
2778

2879

2980
def eval_mitos():
    """Summarize the mitochondrion evaluation over all tomograms."""
    mito_correct = []
    mito_split = []
    mito_merged = []
    mito_total = []
    wrong_object = []

    # Missing counts are treated as zero.
    mito_table = val_table.fillna(0)
    # Measure % of mitos correct, split and merged.
    for _, row in mito_table.iterrows():
        mito_correct.append(row["Mito Correct"])
        mito_split.append(row["Mito Split"])
        mito_merged.append(row["Mito Merged"])
        mito_total.append(row["Mito Total"])
        wrong_object.append(row["Wrong Object"])

    n_mitos = np.sum(mito_total)
    n_correct = np.sum(mito_correct)
    print("Mito Evaluation:")
    # Bug fix: the ratio has to be multiplied by 100 to be a percentage.
    print("Number of correctly identified mitos:", n_correct, "/", n_mitos, f"({100.0 * n_correct / n_mitos}%)")
    print("Number of merged mitos:", np.sum(mito_merged))
    print("Number of split mitos:", np.sum(mito_split))
    print("Number of wrongly identified objects:", np.sum(wrong_object))
38103

39104

40105
def check_mitos():
@@ -57,7 +122,13 @@ def check_mitos():
57122

58123

59124
def main():
    """Run all qualitative evaluations and print their summaries."""
    # check_mitos()  # interactive napari check; disabled by default

    evaluations = [eval_mitos, eval_compartments, eval_az]
    for i, evaluation in enumerate(evaluations):
        if i > 0:
            print()  # blank line between the individual reports
        evaluation()


main()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
data/
2+
exported/
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import os
2+
from glob import glob
3+
from pathlib import Path
4+
5+
import napari
6+
import numpy as np
7+
from elf.io import open_file
8+
from magicgui import magicgui
9+
from synaptic_reconstruction.imod.export import export_point_annotations
10+
11+
EXPORT_FOLDER = "./exported"
12+
13+
14+
def export_vesicles(mrc, mod):
    """Export the IMOD vesicle point annotations for one tomogram to HDF5.

    Skips the export if the output file already exists.
    """
    os.makedirs(EXPORT_FOLDER, exist_ok=True)

    output_path = os.path.join(EXPORT_FOLDER, f"{Path(mrc).stem}.h5")
    if os.path.exists(output_path):
        return  # already exported

    # NOTE(review): presumably the pixel size of the tomograms — confirm.
    resolution = 0.592
    with open_file(mrc, "r") as f:
        data = f["data"][:]

    segmentation, _, _ = export_point_annotations(
        mod, shape=data.shape, resolution=resolution, exclude_labels=[7, 14]
    )
    # Keep only the first entry along axis 0 of both volumes
    # (presumably a singleton axis — TODO confirm).
    data, segmentation = data[0], segmentation[0]

    with open_file(output_path, "a") as f:
        f.create_dataset("data", data=data, compression="gzip")
        f.create_dataset("labels/vesicles", data=segmentation, compression="gzip")
35+
36+
def export_all_vesicles():
    """Export vesicle annotations for every (.mrc, .mod) pair in ./data."""
    mrc_files = sorted(glob(os.path.join("./data", "*.mrc")))
    mod_files = sorted(glob(os.path.join("./data", "*.mod")))
    # Robustness fix: zip would silently drop unmatched files, so fail
    # loudly if the tomograms and annotation files don't pair up.
    if len(mrc_files) != len(mod_files):
        raise RuntimeError(
            f"Number of mrc files ({len(mrc_files)}) does not match "
            f"number of mod files ({len(mod_files)})."
        )
    for mrc, mod in zip(mrc_files, mod_files):
        export_vesicles(mrc, mod)
41+
42+
43+
def create_mask(file_path):
    """Open napari to annotate a mask for one exported tomogram.

    Does nothing if a mask was already saved to the file. The mask drawn
    in the viewer is written back via the "Save Mask" button.
    """
    with open_file(file_path, "r") as f:
        if "labels/mask" in f:
            return  # mask was already annotated
        data = f["data"][:]
        vesicles = f["labels/vesicles"][:]

    # Start from an empty mask with the same shape/dtype as the vesicles.
    # Keep the variable name "mask": napari derives the layer name from it
    # and the save callback looks the layer up as v.layers["mask"].
    mask = np.zeros_like(vesicles)

    viewer = napari.Viewer()
    viewer.add_image(data)
    viewer.add_labels(vesicles)
    viewer.add_labels(mask)

    @magicgui(call_button="Save Mask")
    def save_mask(v: napari.Viewer):
        # Persist the current state of the "mask" layer.
        saved = v.layers["mask"].data.astype("uint8")
        with open_file(file_path, "a") as f:
            f.create_dataset("labels/mask", data=saved, compression="gzip")

    viewer.window.add_dock_widget(save_mask)
    napari.run()
66+
67+
68+
def create_all_masks():
    """Run mask annotation for all exported HDF5 files, one after another."""
    for h5_path in sorted(glob(os.path.join(EXPORT_FOLDER, "*.h5"))):
        create_mask(h5_path)
72+
73+
74+
def main():
    """First export all vesicle annotations, then annotate their masks."""
    export_all_vesicles()
    create_all_masks()


if __name__ == "__main__":
    main()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
AZ_segmentation/
2+
postprocessed_AZ/
3+
az_eval.xlsx
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
3+
import h5py
4+
import napari
5+
6+
from tqdm import tqdm
7+
8+
9+
def check_proofread(raw_path, seg_path):
    """Show a raw tomogram with both proofread AZ segmentations in napari."""
    with h5py.File(raw_path, "r") as f:
        raw = f["raw"][:]
    with h5py.File(seg_path, "r") as f:
        seg1 = f["labels_pp/thin_az"][:]
        seg2 = f["labels_pp/filtered_az"][:]

    viewer = napari.Viewer()
    viewer.add_image(raw)
    viewer.add_labels(seg1)
    viewer.add_labels(seg2)
    napari.run()
21+
22+
23+
def main():
    """Visually check the proofread AZ segmentations of all test tomograms."""
    # FIXME something wrong in the zenodo upload
    root_raw = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/data_summary/for_zenodo/synapse-net/active_zones/train"  # noqa
    root_seg = "./postprocessed_AZ"

    test_tomograms = {
        "01": [
            "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5",  # noqa
            "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5",  # noqa
            "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5",  # noqa
            "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5",  # noqa
            "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5",
            "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5",  # noqa
            "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5",
            "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5",  # noqa
        ],
        "12": ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5",
               "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"],
    }

    # The raw data and the segmentations use different dataset folder names.
    raw_ds_names = {"01": "single_axis_tem", "12": "chemical-fixation"}
    seg_ds_names = {"01": "01_hoi_maus_2020_incomplete", "12": "12_chemical_fix_cryopreparation"}

    for ds, test_tomos in test_tomograms.items():
        for tomo in tqdm(test_tomos, desc=f"Proofread {ds}"):
            raw_path = os.path.join(root_raw, raw_ds_names[ds], tomo)
            seg_path = os.path.join(root_seg, seg_ds_names[ds], tomo)
            check_proofread(raw_path, seg_path)


main()
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
3+
import h5py
4+
import pandas as pd
5+
from elf.evaluation import dice_score
6+
7+
from scipy.ndimage import binary_dilation, binary_closing
8+
from tqdm import tqdm
9+
10+
11+
def _expand_AZ(az):
12+
return binary_closing(
13+
binary_dilation(az, iterations=3), iterations=3
14+
)
15+
16+
17+
def eval_az(seg_path, gt_path, seg_key, gt_key):
    """Compute the dice score of a predicted AZ against the ground truth.

    Both masks are expanded (dilation + closing) before scoring so that
    thin annotations with small offsets still overlap.
    """
    with h5py.File(seg_path, "r") as f:
        seg = f[seg_key][:]
    with h5py.File(gt_path, "r") as f:
        gt = f[gt_key][:]
    assert seg.shape == gt.shape

    score = dice_score(_expand_AZ(seg), _expand_AZ(gt))
    return score
36+
37+
38+
def main():
    """Evaluate the AZ segmentations via dice score.

    The per-tomogram scores are cached in an xlsx file; if it exists the
    scores are only re-read and summarized.
    """
    res_path = "./az_eval.xlsx"
    if os.path.exists(res_path):
        scores = pd.read_excel(res_path)
    else:
        seg_root = "AZ_segmentation/postprocessed_AZ"
        gt_root = "postprocessed_AZ"

        # Removed WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5 from the eval set because of contrast issues
        test_tomograms = {
            "01": [
                "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5",  # noqa
                "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5",  # noqa
                "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5",  # noqa
                "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5",  # noqa
                "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5",
                "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5",  # noqa
                "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5",
                "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5",  # noqa
            ],
            "12": ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5",
                   "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"],
        }

        ds_names = {"01": "01_hoi_maus_2020_incomplete", "12": "12_chemical_fix_cryopreparation"}
        scores = {"Dataset": [], "Tomogram": [], "Dice": []}
        for ds, test_tomos in test_tomograms.items():
            ds_name = ds_names[ds]
            for tomo in tqdm(test_tomos):
                score = eval_az(
                    os.path.join(seg_root, ds_name, tomo),
                    os.path.join(gt_root, ds_name, tomo),
                    seg_key="AZ/thin_az", gt_key="labels_pp/filtered_az",
                )
                scores["Dataset"].append(ds_name)
                scores["Tomogram"].append(tomo)
                scores["Dice"].append(score)

        scores = pd.DataFrame(scores)
        scores.to_excel(res_path, index=False)

    print("Evaluation for the datasets:")
    for ds in pd.unique(scores.Dataset):
        print(ds)
        ds_scores = scores[scores.Dataset == ds]["Dice"]
        print(ds_scores.mean(), "+-", ds_scores.std())

    print("Total:")
    print(scores["Dice"].mean(), "+-", scores["Dice"].std())


main()

0 commit comments

Comments
 (0)