
Commit 168318f

First working version of validation functionality, not fully tested
1 parent 1c58586 commit 168318f

4 files changed, +167 −30 lines changed

flamingo_tools/validation.py

Lines changed: 110 additions & 14 deletions

@@ -7,6 +7,12 @@
 import pandas as pd
 import zarr

+from scipy.ndimage import distance_transform_edt
+from scipy.optimize import linear_sum_assignment
+from skimage.measure import regionprops_table
+from skimage.segmentation import relabel_sequential
+from tqdm import tqdm
+
 from .s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT


@@ -21,16 +27,26 @@ def _normalize_cochlea_name(name):
     return f"{prefix}_{number:06d}_{postfix}"


-# For a less naive annotation we may need to also fetch +- a few slices,
-# so that we have a bit of tolerance with the distance based matching.
+# TODO enable table component filtering with MoBIE table
 def fetch_data_for_evaluation(
     annotation_path: str,
     cache_path: Optional[str] = None,
     seg_name: str = "SGN",
+    z_extent: int = 0,
 ) -> Tuple[np.ndarray, pd.DataFrame]:
     """
     """
+    # Load the annotations and normalize them for the given z-extent.
     annotations = pd.read_csv(annotation_path)
+    annotations = annotations.drop(columns="index")
+    if z_extent == 0:  # If we don't have a z-extent, we just drop the first axis and rename the other two.
+        annotations = annotations.drop(columns="axis-0")
+        annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
+    else:  # Otherwise, we have to center the first axis.
+        # TODO
+        raise NotImplementedError
+
+    # Load the segmentation from the cache path if it is given and the data is already cached.
     if cache_path is not None and os.path.exists(cache_path):
         segmentation = imageio.imread(cache_path)
         return segmentation, annotations
@@ -45,10 +61,17 @@ def fetch_data_for_evaluation(
     internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
     s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)

-    # Download the segmentation for this slice.
+    # Compute the roi for the given z-extent.
+    if z_extent == 0:
+        roi = slice_id
+    else:
+        roi = slice(slice_id - z_extent, slice_id + z_extent)
+
+    # Download the segmentation for this slice and the given z-extent.
     input_key = "s0"
     with zarr.open(s3_store, mode="r") as f:
-        segmentation = f[input_key][slice_id]
+        segmentation = f[input_key][roi]
+    segmentation, _, _ = relabel_sequential(segmentation)

     # Cache it if required.
     if cache_path is not None:
@@ -57,7 +80,61 @@ def fetch_data_for_evaluation(
     return segmentation, annotations


-def evaluate_annotated_slice(
+def compute_matches_for_annotated_slice(
+    segmentation: np.typing.ArrayLike,
+    annotations: pd.DataFrame,
+    matching_tolerance: float = 0.0,
+) -> Dict[str, np.ndarray]:
+    """Computes the ids of matches and non-matches for an annotated validation slice.
+
+    Computes true positive ids (for objects and annotations), false positive ids and false negative ids
+    by solving a linear cost assignment of distances between objects and annotations.
+
+    Args:
+        segmentation: The segmentation for this slice. We assume that it is relabeled consecutively.
+        annotations: The annotations, marking cell centers.
+        matching_tolerance: The maximum distance for matching an annotation to a segmented object.
+
+    Returns:
+        A dictionary with keys 'tp_objects', 'tp_annotations', 'fp' and 'fn', mapping to the respective ids.
+    """
+    assert segmentation.ndim in (2, 3)
+    segmentation_ids = np.unique(segmentation)[1:]
+    n_objects, n_annotations = len(segmentation_ids), len(annotations)
+
+    # In order to get the full distance matrix, we compute the distance to all objects for each annotation.
+    # This is not very efficient, but it's the most straightforward and most rigorous approach.
+    scores = np.zeros((n_objects, n_annotations), dtype="float")
+    coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
+    for i, row in tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances"):
+        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
+        distance_input = np.ones(segmentation.shape, dtype="bool")
+        distance_input[coordinate] = False
+        distances, indices = distance_transform_edt(distance_input, return_indices=True)
+
+        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
+        distances = props["min_intensity"]
+        assert len(distances) == scores.shape[0]
+        scores[:, i] = distances
+
+    # Find the assignment of points to objects.
+    # These correspond to the TP ids in the point / object annotations.
+    tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
+    match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
+    tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
+    tp_ids_objects = segmentation_ids[tp_ids_objects]
+    assert len(tp_ids_objects) == len(tp_ids_annotations)
+
+    # Find the false positives: objects that are not part of the matches.
+    fp_ids = np.setdiff1d(segmentation_ids, tp_ids_objects)
+
+    # Find the false negatives: annotations that are not part of the matches.
+    fn_ids = np.setdiff1d(np.arange(n_annotations), tp_ids_annotations)
+
+    return {"tp_objects": tp_ids_objects, "tp_annotations": tp_ids_annotations, "fp": fp_ids, "fn": fn_ids}


+def compute_scores_for_annotated_slice(
     segmentation: np.typing.ArrayLike,
     annotations: pd.DataFrame,
     matching_tolerance: float = 0.0,
@@ -67,20 +144,39 @@ def evaluate_annotated_slice(
     Computes true positives, false positives and false negatives for scoring.

     Args:
-        segmentation: The segmentation for this slide.
+        segmentation: The segmentation for this slice. We assume that it is relabeled consecutively.
         annotations: The annotations, marking cell centers.
-        matching_tolerance: ...
+        matching_tolerance: The maximum distance for matching an annotation to a segmented object.

     Returns:
         A dictionary with keys 'tp', 'fp' and 'fn', mapping to the respective counts.
     """
-    # Compute the distance transform and nearest id fields.
+    result = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance)

-    # Match all of the points to segmented objects based on their distance.
+    # Determine the TPs, FPs and FNs from the match results.
+    tp = len(result["tp_objects"])
+    fp = len(result["fp"])
+    fn = len(result["fn"])
+    return {"tp": tp, "fp": fp, "fn": fn}

-    # Determine the TPs, FPs and FNs based on a linear cost assignment.
-    tp = ...
-    fp = ...
-    fn = ...

-    return {"tp": tp, "fp": fp, "fn": fn}
+def for_visualization(segmentation, annotations, matches):
+    green_red = ["#00FF00", "#FF0000"]
+
+    seg_vis = np.zeros_like(segmentation)
+    tps, fps = matches["tp_objects"], matches["fp"]
+    seg_vis[np.isin(segmentation, tps)] = 1
+    seg_vis[np.isin(segmentation, fps)] = 2
+
+    # TODO red / green colormap
+    seg_props = dict(color={1: green_red[0], 2: green_red[1]})
+
+    point_vis = annotations.copy()
+    tps = matches["tp_annotations"]
+    point_props = dict(
+        properties={"match": [0 if aid in tps else 1 for aid in range(len(annotations))]},
+        border_color="match",
+        border_color_cycle=green_red,
+    )
+
+    return seg_vis, point_vis, seg_props, point_props
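
The core of the new matching logic is the linear cost assignment over the object-to-annotation distance matrix. As a standalone sketch (not part of this commit, with invented distance values), this is how scipy's linear_sum_assignment picks the globally optimal one-to-one pairing and how the tolerance threshold then discards distant pairs, mirroring the steps in compute_matches_for_annotated_slice:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Pairwise distances between segmented objects (rows) and annotations (columns).
# The values here are made up for illustration.
scores = np.array([
    [0.0, 19.8],
    [6.3, 18.7],
])

# linear_sum_assignment finds the one-to-one pairing with minimal total distance.
object_idx, annotation_idx = linear_sum_assignment(scores)
print(object_idx, annotation_idx)  # [0 1] [0 1]

# As in compute_matches_for_annotated_slice, pairs above the matching
# tolerance are discarded; here only the first pair survives.
matching_tolerance = 2.0
match_ok = scores[object_idx, annotation_idx] <= matching_tolerance
print(object_idx[match_ok], annotation_idx[match_ok])  # [0] [0]

Because the assignment is solved globally, a single object cannot absorb two nearby annotations, which a greedy nearest-neighbor match would allow.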
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+import os
+
+import imageio.v3 as imageio
+import napari
+import pandas as pd
+
+# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation"
+ROOT = "annotation_data"
+TEST_ANNOTATION = os.path.join(ROOT, "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv")
+
+
+def check_annotation(image_path, annotation_path):
+    annotations = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values
+
+    image = imageio.imread(image_path)
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_points(annotations)
+    napari.run()
+
+
+def main():
+    check_annotation(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"), TEST_ANNOTATION)
+
+
+if __name__ == "__main__":
+    main()
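
Judging by the column handling above (an "index" column that fetch_data_for_evaluation drops, plus "axis-0"/"axis-1"/"axis-2" coordinates), the annotation CSVs follow the layout of napari's point-layer export. A minimal sketch for writing a compatible test file; the file name and coordinate values are invented for illustration:

import pandas as pd

# Hypothetical example points in (z, y, x) order; values are made up.
points = pd.DataFrame({
    "axis-0": [771.0, 771.0],
    "axis-1": [120.5, 240.0],
    "axis-2": [88.0, 310.25],
})
points.index.name = "index"
# Writes the header "index,axis-0,axis-1,axis-2" expected by the loaders above.
points.to_csv("example_annotations.csv")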

scripts/validation/develop_f1_val.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import os
+
+import imageio.v3 as imageio
+import napari
+
+from flamingo_tools.validation import fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization
+
+# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation"
+ROOT = "annotation_data"
+TEST_ANNOTATION = os.path.join(ROOT, "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv")
+
+
+def main():
+    segmentation, annotations = fetch_data_for_evaluation(TEST_ANNOTATION, cache_path="./seg.tif")
+    matches = compute_matches_for_annotated_slice(segmentation, annotations)
+    vis_segmentation, vis_points, seg_props, point_props = for_visualization(segmentation, annotations, matches)
+
+    image = imageio.imread(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"))
+
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_labels(vis_segmentation, **seg_props)
+    v.add_points(vis_points, **point_props)
+    v.add_labels(segmentation, visible=False)
+    v.add_points(annotations, visible=False)
+    napari.run()
+
+
+if __name__ == "__main__":
+    main()
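
This example stops at visualization. To turn the matches into numbers, the new compute_scores_for_annotated_slice can be used to derive precision, recall and F1 from the returned counts. A hypothetical extension, not part of the commit; the tolerance value of 5.0 is made up and the metric formulas are the standard definitions:

import os

from flamingo_tools.validation import fetch_data_for_evaluation, compute_scores_for_annotated_slice

# Same test annotation as in the script above.
TEST_ANNOTATION = os.path.join("annotation_data", "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv")

segmentation, annotations = fetch_data_for_evaluation(TEST_ANNOTATION, cache_path="./seg.tif")
counts = compute_scores_for_annotated_slice(segmentation, annotations, matching_tolerance=5.0)

# Standard precision / recall / F1 from the tp, fp, fn counts.
tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
print(f"tp={tp} fp={fp} fn={fn} precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")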
