Skip to content

Commit f76dbf7

Browse files
Add attempt at compartment evaluation code
1 parent d8d7c1b commit f76dbf7

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import os
2+
import h5py
3+
import numpy as np
4+
import pandas as pd
5+
6+
from synapse_net.inference.inference import get_model
7+
from synapse_net.inference.compartments import segment_compartments
8+
from skimage.segmentation import find_boundaries
9+
10+
from elf.evaluation.matching import matching
11+
12+
from train_compartments import get_paths_3d
13+
from sklearn.model_selection import train_test_split
14+
15+
16+
def run_prediction(paths):
17+
output_folder = "./compartment_eval"
18+
os.makedirs(output_folder, exist_ok=True)
19+
20+
model = get_model("compartments")
21+
for path in paths:
22+
with h5py.File(path, "r") as f:
23+
input_vol = f["raw"][:]
24+
seg, pred = segment_compartments(input_vol, model=model, return_predictions=True)
25+
fname = os.path.basename(path)
26+
out = os.path.join(output_folder, fname)
27+
with h5py.File(out, "a") as f:
28+
f.create_dataset("seg", data=seg, compression="gzip")
29+
f.create_dataset("pred", data=pred, compression="gzip")
30+
31+
32+
def binary_recall(gt, pred):
33+
tp = np.logical_and(gt, pred).sum()
34+
fn = np.logical_and(gt, ~pred).sum()
35+
return float(tp) / (tp + fn) if (tp + fn) else 0.0
36+
37+
38+
def run_evaluation(paths):
39+
output_folder = "./compartment_eval"
40+
41+
results = {
42+
"name": [],
43+
"recall-pred": [],
44+
"recall-seg": [],
45+
}
46+
47+
for path in paths:
48+
with h5py.File(path, "r") as f:
49+
labels = f["labels/compartments"][:]
50+
boundary_labels = find_boundaries(labels).astype("bool")
51+
52+
fname = os.path.basename(path)
53+
out = os.path.join(output_folder, fname)
54+
with h5py.File(out, "a") as f:
55+
seg, pred = f["seg"][:], f["pred"][:]
56+
57+
recall_pred = binary_recall(boundary_labels, pred > 0.5)
58+
recall_seg = matching(seg, labels)["recall"]
59+
60+
results["name"].append(fname)
61+
results["recall-pred"].append(recall_pred)
62+
results["recall-seg"].append(recall_seg)
63+
64+
results = pd.DataFrame(results)
65+
print(results)
66+
print(results[["recall-pred", "recall-seg"]].mean())
67+
68+
69+
def check_predictions(paths):
70+
import napari
71+
output_folder = "./compartment_eval"
72+
73+
for path in paths:
74+
with h5py.File(path, "r") as f:
75+
raw = f["raw"][:]
76+
labels = f["labels/compartments"][:]
77+
boundary_labels = find_boundaries(labels)
78+
79+
fname = os.path.basename(path)
80+
out = os.path.join(output_folder, fname)
81+
with h5py.File(out, "a") as f:
82+
seg, pred = f["seg"][:], f["pred"][:]
83+
84+
v = napari.Viewer()
85+
v.add_image(raw)
86+
v.add_image(pred)
87+
v.add_labels(labels)
88+
v.add_labels(boundary_labels)
89+
v.add_labels(seg)
90+
napari.run()
91+
92+
93+
def main():
94+
paths = get_paths_3d()
95+
_, val_paths = train_test_split(paths, test_size=0.10, random_state=42)
96+
97+
# run_prediction(val_paths)
98+
run_evaluation(val_paths)
99+
# check_predictions(val_paths)
100+
101+
102+
if __name__ == "__main__":
103+
main()

scripts/cooper/training/train_compartments.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from synapse_net.training import supervised_training
1515

1616
TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/compartments"
17-
# TRAIN_ROOT = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/ground_truth/compartments/output/compartment_gt" # noqa
1817

1918

2019
def get_paths_2d():

0 commit comments

Comments
 (0)