Commit c1530f2
Simplified working implementation of validation functionality
1 parent ab87fb9 commit c1530f2

File tree

2 files changed, +96 -30 lines

flamingo_tools/validation.py

Lines changed: 83 additions & 29 deletions
@@ -97,7 +97,59 @@ def fetch_data_for_evaluation(
     return segmentation, annotations


-# TODO crop to the bounding box around the union of points and segmentation masks to be more efficient.
+# We should use the Hungarian-based matching, but I can't find the bug in it right now.
+def _naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
+    distances, indices = distance_transform_edt(segmentation == 0, return_indices=True)
+
+    matched_ids = {}
+    matched_distances = {}
+    annotation_id = 0
+    for _, row in annotations.iterrows():
+        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
+        object_distance = distances[coordinate]
+        if object_distance <= matching_tolerance:
+            closest_object_coord = tuple(idx[coordinate] for idx in indices)
+            object_id = segmentation[closest_object_coord]
+            if object_id not in matched_ids or matched_distances[object_id] > object_distance:
+                matched_ids[object_id] = annotation_id
+                matched_distances[object_id] = object_distance
+        annotation_id += 1
+
+    tp_ids_objects = np.array(list(matched_ids.keys()))
+    tp_ids_annotations = np.array(list(matched_ids.values()))
+    return tp_ids_objects, tp_ids_annotations
+
+
+# There is a bug in here that neither I nor o3 can figure out ...
+def _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
+    n_objects, n_annotations = len(segmentation_ids), len(annotations)
+
+    # To get the full distance matrix, we compute the distance to all objects for each annotation.
+    # This is not very efficient, but it's the most straightforward and most rigorous approach.
+    scores = np.zeros((n_objects, n_annotations), dtype="float")
+    i = 0
+    for _, row in tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances"):
+        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
+        distance_input = np.ones(segmentation.shape, dtype="bool")
+        distance_input[coordinate] = False
+        distances = distance_transform_edt(distance_input)
+
+        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
+        distances = props["min_intensity"]
+        assert len(distances) == scores.shape[0]
+        scores[:, i] = distances
+        i += 1
+
+    # Find the assignment of points to objects.
+    # These correspond to the TP ids in the point / object annotations.
+    tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
+    match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
+    tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
+    tp_ids_objects = segmentation_ids[tp_ids_objects]
+
+    return tp_ids_objects, tp_ids_annotations
+
+
 def compute_matches_for_annotated_slice(
     segmentation: np.typing.ArrayLike,
     annotations: pd.DataFrame,
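
Aside: _naive_matching hinges on scipy's distance_transform_edt with return_indices=True, which returns, for every background pixel, both the distance to the nearest labeled pixel and that pixel's coordinates, so each annotation can be mapped to its closest object with a single transform. A minimal sketch on toy 2D data (the arrays below are made up for illustration):

import numpy as np
from scipy.ndimage import distance_transform_edt

# Toy segmentation with two labeled objects.
segmentation = np.zeros((5, 5), dtype="uint32")
segmentation[0, 0] = 1
segmentation[4, 4] = 2

# For each background pixel, `distances` holds the distance to the nearest
# object pixel and `indices` holds that object pixel's coordinates.
distances, indices = distance_transform_edt(segmentation == 0, return_indices=True)

annotation = (1, 1)  # a point annotation near object 1
closest_object_coord = tuple(idx[annotation] for idx in indices)
print(distances[annotation])               # ~1.41, i.e. sqrt(2)
print(segmentation[closest_object_coord])  # 1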
@@ -117,37 +169,35 @@ def compute_matches_for_annotated_slice(
         A dictionary with keys 'tp_objects', 'tp_annotations', 'fp' and 'fn', mapping to the respective ids.
     """
     assert segmentation.ndim in (2, 3)
-    segmentation_ids = np.unique(segmentation)[1:]
-    n_objects, n_annotations = len(segmentation_ids), len(annotations)
-
-    # In order to get the full distance matrix, we compute the distance to all objects for each annotation.
-    # This is not very efficient, but it's the most straight-forward and most rigorous approach.
-    scores = np.zeros((n_objects, n_annotations), dtype="float")
     coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
-    for i, row in tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances"):
-        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
-        distance_input = np.ones(segmentation.shape, dtype="bool")
-        distance_input[coordinate] = False
-        distances, indices = distance_transform_edt(distance_input, return_indices=True)
-
-        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
-        distances = props["min_intensity"]
-        assert len(distances) == scores.shape[0]
-        scores[:, i] = distances
+    segmentation_ids = np.unique(segmentation)[1:]

-    # Find the assignment of points to objects.
-    # These correspond to the TP ids in the point / object annotations.
-    tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
-    match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
-    tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
-    tp_ids_objects = segmentation_ids[tp_ids_objects]
+    # Crop to the minimal enclosing bounding box of points and segmented objects.
+    bb_seg = np.where(segmentation != 0)
+    bb_seg = tuple(slice(int(bb.min()), int(bb.max())) for bb in bb_seg)
+    bb_points = tuple(
+        slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1)
+        for coords in coordinates
+    )
+    bbox = tuple(slice(min(bbs.start, bbp.start), max(bbs.stop, bbp.stop)) for bbs, bbp in zip(bb_seg, bb_points))
+    segmentation = segmentation[bbox]
+
+    annotations = annotations.copy()
+    for coord, bb in zip(coordinates, bbox):
+        annotations[coord] -= bb.start
+        assert (annotations[coord] <= bb.stop).all()
+
+    # tp_ids_objects, tp_ids_annotations =\
+    #     _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
+    tp_ids_objects, tp_ids_annotations =\
+        _naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
     assert len(tp_ids_objects) == len(tp_ids_annotations)

     # Find the false positives: objects that are not part of the matches.
     fp_ids = np.setdiff1d(segmentation_ids, tp_ids_objects)

     # Find the false negatives: annotations that are not part of the matches.
-    fn_ids = np.setdiff1d(np.arange(n_annotations), tp_ids_annotations)
+    fn_ids = np.setdiff1d(np.arange(len(annotations)), tp_ids_annotations)

     return {"tp_objects": tp_ids_objects, "tp_annotations": tp_ids_annotations, "fp": fp_ids, "fn": fn_ids}
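
The id arrays in the returned dictionary translate directly into detection metrics; a sketch of that follow-up step (the metric computation itself is not part of this commit):

matches = compute_matches_for_annotated_slice(segmentation, annotations)

n_tp = len(matches["tp_annotations"])
n_fp = len(matches["fp"])
n_fn = len(matches["fn"])
precision = n_tp / (n_tp + n_fp)
recall = n_tp / (n_tp + n_fn)
f1_score = 2 * precision * recall / (precision + recall)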

@@ -186,15 +236,19 @@ def for_visualization(segmentation, annotations, matches):
     seg_vis[np.isin(segmentation, tps)] = 1
     seg_vis[np.isin(segmentation, fps)] = 2

-    # TODO red / green colormap
-    seg_props = dict(color={1: green_red[0], 2: green_red[1]})
+    seg_props = dict(colormap={1: green_red[0], 2: green_red[1]})

     point_vis = annotations.copy()
     tps = matches["tp_annotations"]
     point_props = dict(
-        properties={"match": [0 if aid in tps else 1 for aid in range(len(annotations))]},
-        border_color="match",
-        border_color_cycle=green_red,
+        properties={
+            "id": list(range(len(annotations))),
+            "match": ["tp" if aid in tps else "fn" for aid in range(len(annotations))]
+        },
+        face_color="match",
+        face_color_cycle=green_red[::-1],
+        border_width=0.25,
+        size=10,
     )

     return seg_vis, point_vis, seg_props, point_props
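
The two props dicts are shaped so that they can be unpacked straight into napari layer keywords; a sketch of the intended wiring (assuming a napari version that accepts the colormap and border_width keywords used above):

import napari

seg_vis, point_vis, seg_props, point_props = for_visualization(segmentation, annotations, matches)

v = napari.Viewer()
v.add_labels(seg_vis, **seg_props)      # TP objects in green, FP objects in red
v.add_points(point_vis, **point_props)  # annotations colored by their 'match' value (tp vs. fn)
napari.run()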

scripts/validation/visualize_validation.py

Lines changed: 13 additions & 1 deletion
@@ -11,11 +11,23 @@


 def main():
+    image = imageio.imread(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"))
     segmentation, annotations = fetch_data_for_evaluation(TEST_ANNOTATION, cache_path="./seg.tif")
+
+    # v = napari.Viewer()
+    # v.add_image(image)
+    # v.add_labels(segmentation)
+    # v.add_points(annotations)
+    # napari.run()
+
     matches = compute_matches_for_annotated_slice(segmentation, annotations)
+    tps, fns = matches["tp_annotations"], matches["fn"]
     vis_segmentation, vis_points, seg_props, point_props = for_visualization(segmentation, annotations, matches)

-    image = imageio.imread(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"))
+    print("True positive annotations:")
+    print(tps)
+    print("False negative annotations:")
+    print(fns)

     v = napari.Viewer()
     v.add_image(image)
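
Given the open bug in _assignment_based_matching, a tiny synthetic case with known ground truth is useful for comparing the two matchers; a sketch, assuming compute_matches_for_annotated_slice exposes the matching_tolerance keyword that the helpers take:

import numpy as np
import pandas as pd

from flamingo_tools.validation import compute_matches_for_annotated_slice

# Two square objects and three point annotations: one inside each object,
# one far away from both.
segmentation = np.zeros((32, 32), dtype="uint32")
segmentation[4:8, 4:8] = 1
segmentation[20:24, 20:24] = 2
annotations = pd.DataFrame({"axis-0": [5, 22, 30], "axis-1": [5, 22, 2]})

matches = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance=0.0)
assert set(matches["tp_objects"]) == {1, 2}
assert list(matches["fn"]) == [2]  # the far-away annotation stays unmatched
assert len(matches["fp"]) == 0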
