Skip to content

Commit 39d69d5

Browse files
Implement AZ evaluation WIP
1 parent 7dd6962 commit 39d69d5

File tree

3 files changed

+167
-5
lines changed

3 files changed

+167
-5
lines changed
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,31 @@
11
from synapse_net.sample_data import get_sample_data
2-
from synapse_net.inference import run_segmentation, get_model
32
from elf.io import open_file
43

54

65
sample_data = get_sample_data("tem_tomo")
76
tomo = open_file(sample_data, "r")["data"][:]
87

9-
model = get_model("active_zone")
10-
seg = run_segmentation(tomo, model, "active_zone")
118

12-
with open_file("./pred.h5", "a") as f:
13-
f.create_dataset("pred", data=seg, compression="gzip")
9+
def run_prediction():
    """Segment the active zone in the sample tomogram and store the result in ./pred.h5."""
    from synapse_net.inference import run_segmentation, get_model

    # Load the pretrained active-zone model and run segmentation on the tomogram.
    model_name = "active_zone"
    model = get_model(model_name)
    segmentation = run_segmentation(tomo, model, model_name)

    # Append the prediction to the output file so reruns don't clobber other datasets.
    with open_file("./pred.h5", "a") as out:
        out.create_dataset("pred", data=segmentation, compression="gzip")
17+
18+
19+
def check_prediction():
    """Open the stored prediction next to the tomogram in napari for visual inspection."""
    import napari

    # Read back the segmentation written by run_prediction.
    with open_file("./pred.h5", "r") as pred_file:
        prediction = pred_file["pred"][:]

    viewer = napari.Viewer()
    viewer.add_image(tomo)
    viewer.add_labels(prediction)
    napari.run()
29+
30+
31+
check_prediction()
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import argparse
2+
import os
3+
from glob import glob
4+
5+
6+
def _get_paths(seg_root, gt_root, image_root=None):
7+
seg_paths = sorted(glob(os.path.join(seg_root, "**/*.h5"), recursive=True))
8+
gt_paths = sorted(glob(os.path.join(gt_root, "**/*.h5"), recursive=True))
9+
assert len(seg_paths) == len(gt_paths)
10+
11+
if image_root is None:
12+
image_paths = [None] * len(seg_paths)
13+
else:
14+
image_paths = sorted(glob(os.path.join(image_root, "**/*.mrc"), recursive=True))
15+
assert len(image_paths) == len(seg_paths)
16+
17+
return seg_paths, gt_paths, image_paths
18+
19+
20+
# TODO extend this
def run_az_evaluation(args):
    """Run the active-zone evaluation for all matched path pairs and print the result table."""
    from synapse_net.ground_truth.az_evaluation import az_evaluation

    seg_files, gt_files, _ = _get_paths(args.seg_root, args.gt_root)
    evaluation_table = az_evaluation(seg_files, gt_files, seg_key="seg", gt_key="gt")
    print(evaluation_table)
28+
29+
30+
def visualize_az_evaluation(args):
    """Interactively inspect each segmentation / ground-truth pair with the metric visualization tool."""
    from elf.visualisation.metric_visualization import run_metric_visualization
    from synapse_net.ground_truth.az_evaluation import _postprocess
    from elf.io import open_file

    def _load(path, key):
        # Read the full volume stored under the given key.
        with open_file(path, "r") as f:
            return f[key][:]

    seg_paths, gt_paths, image_paths = _get_paths(args.seg_root, args.gt_root, args.image_root)
    for seg_path, gt_path, image_path in zip(seg_paths, gt_paths, image_paths):
        image = _load(image_path, "data") if image_path is not None else None

        seg = _load(seg_path, "seg")
        gt = _load(gt_path, "gt")

        # Apply the same post-processing as the evaluation so the view matches the scores.
        seg = _postprocess(seg, apply_cc=True, min_component_size=100)
        gt = _postprocess(gt, apply_cc=True, min_component_size=100)

        run_metric_visualization(image, seg, gt)
48+
49+
50+
def main():
    """Command line interface for evaluating or visualizing active-zone segmentations."""
    parser = argparse.ArgumentParser(
        description="Evaluate active zone segmentations against ground-truth annotations."
    )
    parser.add_argument("-s", "--seg_root", required=True,
                        help="Root folder with the segmentation hdf5 files.")
    parser.add_argument("-g", "--gt_root", required=True,
                        help="Root folder with the ground-truth hdf5 files.")
    parser.add_argument("-i", "--image_root",
                        help="Optional root folder with the image data (mrc files), used for visualization.")
    parser.add_argument("--visualize", action="store_true",
                        help="Visualize the evaluation per tomogram instead of printing the score table.")
    args = parser.parse_args()

    if args.visualize:
        visualize_az_evaluation(args)
    else:
        run_az_evaluation(args)


if __name__ == "__main__":
    main()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import os
2+
from typing import List
3+
4+
import h5py
5+
import pandas as pd
6+
import numpy as np
7+
8+
from elf.evaluation.matching import _compute_scores, _compute_tps
9+
from elf.evaluation import dice_score
10+
from skimage.measure import label
11+
from tqdm import tqdm
12+
13+
14+
def _postprocess(data, apply_cc, min_component_size):
15+
if apply_cc:
16+
data = label(data)
17+
ids, sizes = np.unique(data, return_counts=True)
18+
filter_ids = ids[sizes < min_component_size]
19+
data[np.isin(data, filter_ids)] = 0
20+
return data
21+
22+
23+
def _single_az_evaluation(seg, gt, apply_cc, min_component_size):
    """Compute the dice score and instance-matching statistics (tp / fp / fn at IoU 0.5) for one tomogram."""
    assert seg.shape == gt.shape, f"{seg.shape}, {gt.shape}"
    seg = _postprocess(seg, apply_cc, min_component_size)
    gt = _postprocess(gt, apply_cc, min_component_size)

    # Semantic overlap of the binarized masks.
    dice = dice_score(seg > 0, gt > 0)

    # Instance-level matching at an IoU threshold of 0.5.
    n_true, n_matched, n_pred, scores = _compute_scores(seg, gt, criterion="iou", ignore_label=0)
    true_positives = _compute_tps(scores, n_matched, threshold=0.5)

    return {
        "tp": true_positives,
        "fp": n_pred - true_positives,
        "fn": n_true - true_positives,
        "dice": dice,
    }
36+
37+
38+
# TODO further post-processing?
def az_evaluation(
    seg_paths: List[str],
    gt_paths: List[str],
    seg_key: str,
    gt_key: str,
    apply_cc: bool = True,
    min_component_size: int = 100,  # TODO
) -> pd.DataFrame:
    """Evaluate active zone segmentations against ground-truth annotations.

    Args:
        seg_paths: The filepaths to the segmentations, stored as hdf5 files.
        gt_paths: The filepaths to the ground-truth annotations, stored as hdf5 files.
        seg_key: The internal path to the data in the segmentation hdf5 file.
        gt_key: The internal path to the data in the ground-truth hdf5 file.
        apply_cc: Whether to apply connected components before evaluation.
        min_component_size: Minimum component size for filtering the segmentation and annotations before evaluation.

    Returns:
        A data frame with the evaluation results per tomogram.
    """
    # Segmentation and ground-truth files are matched positionally, so the counts must agree.
    assert len(seg_paths) == len(gt_paths), \
        f"Number of segmentation paths ({len(seg_paths)}) and ground-truth paths ({len(gt_paths)}) does not match."

    results = {
        "tomo_name": [],
        "tp": [],
        "fp": [],
        "fn": [],
        "dice": [],
    }
    for seg_path, gt_path in tqdm(zip(seg_paths, gt_paths), total=len(seg_paths), desc="Run AZ Eval"):
        with h5py.File(seg_path, "r") as f:
            seg = f[seg_key][:]
        with h5py.File(gt_path, "r") as f:
            gt = f[gt_key][:]
        # TODO more post-processing params
        result = _single_az_evaluation(seg, gt, apply_cc, min_component_size)
        results["tomo_name"].append(os.path.basename(seg_path))
        for metric in ("tp", "fp", "fn", "dice"):
            results[metric].append(result[metric])
    return pd.DataFrame(results)

0 commit comments

Comments
 (0)