Commit 3e26f80

Merge branch 'master' into domain-adaptation
2 parents 8d03fe2 + 05e3912 commit 3e26f80

6 files changed (+462, -10 lines)

flamingo_tools/validation.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
import os
import re
from typing import Dict, List, Optional, Tuple

import imageio.v3 as imageio
import numpy as np
import pandas as pd
import zarr

from scipy.ndimage import distance_transform_edt
from scipy.optimize import linear_sum_assignment
from skimage.measure import regionprops_table
from skimage.segmentation import relabel_sequential
from tqdm import tqdm

from .s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT


def _normalize_cochlea_name(name):
    match = re.search(r"\d+", name)
    pos = match.start() if match else None
    assert pos is not None, name
    prefix = name[:pos]
    prefix = f"{prefix[0]}_{prefix[1:]}"
    number = int(name[pos:-1])
    postfix = name[-1]
    return f"{prefix}_{number:06d}_{postfix}"


def parse_annotation_path(annotation_path):
    fname = os.path.basename(annotation_path)
    name_parts = fname.split("_")
    cochlea = _normalize_cochlea_name(name_parts[0])
    slice_id = int(name_parts[2][1:])
    return cochlea, slice_id


# TODO enable table component filtering with MoBIE table
# NOTE: the main component is always #1
def fetch_data_for_evaluation(
    annotation_path: str,
    cache_path: Optional[str] = None,
    seg_name: str = "SGN",
    z_extent: int = 0,
    components_for_postprocessing: Optional[List[int]] = None,
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Fetch the segmentation and the point annotations for an annotated validation slice.

    Args:
        annotation_path: The path to the csv file with the point annotations.
        cache_path: Optional path for caching the downloaded segmentation slice.
        seg_name: The name of the segmentation in MoBIE.
        z_extent: The z-extent around the annotated slice. Only a value of 0 is currently supported.
        components_for_postprocessing: Optional list of component ids. If given, only the
            segmentation objects that belong to these components are kept.

    Returns:
        The segmentation for the annotated slice and the normalized annotations.
    """
    # Load the annotations and normalize them for the given z-extent.
    annotations = pd.read_csv(annotation_path)
    annotations = annotations.drop(columns="index")
    if z_extent == 0:  # If we don't have a z-extent then we just drop the first axis and rename the other two.
        annotations = annotations.drop(columns="axis-0")
        annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
    else:  # Otherwise we have to center the first axis.
        # TODO
        raise NotImplementedError

    # Load the segmentation from the cache path if it is given and the data is already cached.
    if cache_path is not None and os.path.exists(cache_path):
        segmentation = imageio.imread(cache_path)
        return segmentation, annotations

    # Parse the cochlea name and the slice id from the annotation file name.
    cochlea, slice_id = parse_annotation_path(annotation_path)

    # Open the S3 connection, get the path to the SGN segmentation in S3.
    internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
    s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)

    # Compute the roi for the given z-extent.
    if z_extent == 0:
        roi = slice_id
    else:
        roi = slice(slice_id - z_extent, slice_id + z_extent)

    # Download the segmentation for this slice and the given z-extent.
    input_key = "s0"
    with zarr.open(s3_store, mode="r") as f:
        segmentation = f[input_key][roi]

    if components_for_postprocessing is not None:
        # Filter the label ids so that only the ones that are part of 'components_for_postprocessing' remain.

        # First, we download the MoBIE table for this segmentation.
        internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv")
        with fs.open(internal_path, "r") as f:
            table = pd.read_csv(f, sep="\t")

        # Then we get the ids for the components and use them to filter the segmentation.
        component_mask = np.isin(table.component_labels.values, components_for_postprocessing)
        keep_label_ids = table.label_id.values[component_mask].astype("int64")
        filter_mask = ~np.isin(segmentation, keep_label_ids)
        segmentation[filter_mask] = 0

    segmentation, _, _ = relabel_sequential(segmentation)

    # Cache it if required.
    if cache_path is not None:
        imageio.imwrite(cache_path, segmentation, compression="zlib")

    return segmentation, annotations


# We should use the Hungarian-based matching, but I can't find the bug in it right now.
def _naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
    # For each background pixel, compute the distance to the closest object and the index of that object.
    distances, indices = distance_transform_edt(segmentation == 0, return_indices=True)

    matched_ids = {}
    matched_distances = {}
    annotation_id = 0
    for _, row in annotations.iterrows():
        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
        object_distance = distances[coordinate]
        if object_distance <= matching_tolerance:
            closest_object_coord = tuple(idx[coordinate] for idx in indices)
            object_id = segmentation[closest_object_coord]
            # If the object was already matched, keep the annotation with the smaller distance.
            if object_id not in matched_ids or matched_distances[object_id] > object_distance:
                matched_ids[object_id] = annotation_id
                matched_distances[object_id] = object_distance
        annotation_id += 1

    tp_ids_objects = np.array(list(matched_ids.keys()))
    tp_ids_annotations = np.array(list(matched_ids.values()))
    return tp_ids_objects, tp_ids_annotations


# There is a bug in here that neither I nor o3 can figure out ...
def _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
    n_objects, n_annotations = len(segmentation_ids), len(annotations)

    # In order to get the full distance matrix, we compute the distance to all objects for each annotation.
    # This is not very efficient, but it's the most straightforward and most rigorous approach.
    scores = np.zeros((n_objects, n_annotations), dtype="float")
    i = 0
    for _, row in tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances"):
        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
        distance_input = np.ones(segmentation.shape, dtype="bool")
        distance_input[coordinate] = False
        distances = distance_transform_edt(distance_input)

        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
        distances = props["min_intensity"]
        assert len(distances) == scores.shape[0]
        scores[:, i] = distances
        i += 1

    # Find the assignment of points to objects.
    # These correspond to the TP ids in the point / object annotations.
    tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
    match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
    tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
    tp_ids_objects = segmentation_ids[tp_ids_objects]

    return tp_ids_objects, tp_ids_annotations


def compute_matches_for_annotated_slice(
    segmentation: np.typing.ArrayLike,
    annotations: pd.DataFrame,
    matching_tolerance: float = 0.0,
) -> Dict[str, np.ndarray]:
    """Computes the ids of matches and non-matches for an annotated validation slice.

    Computes true positive ids (for objects and annotations), false positive ids and false negative ids
    by solving a linear cost assignment of distances between objects and annotations.

    Args:
        segmentation: The segmentation for this slice. We assume that it is relabeled consecutively.
        annotations: The annotations, marking cell centers.
        matching_tolerance: The maximum distance for matching an annotation to a segmented object.

    Returns:
        A dictionary with keys 'tp_objects', 'tp_annotations', 'fp' and 'fn', mapping to the respective ids.
    """
    assert segmentation.ndim in (2, 3)
    coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
    segmentation_ids = np.unique(segmentation)[1:]

    # Crop to the minimal enclosing bounding box of points and segmented objects.
    bb_seg = np.where(segmentation != 0)
    bb_seg = tuple(slice(int(bb.min()), int(bb.max()) + 1) for bb in bb_seg)  # + 1 to include the max coordinate.
    bb_points = tuple(
        slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1)
        for coords in coordinates
    )
    bbox = tuple(slice(min(bbs.start, bbp.start), max(bbs.stop, bbp.stop)) for bbs, bbp in zip(bb_seg, bb_points))
    segmentation = segmentation[bbox]

    annotations = annotations.copy()
    for coord, bb in zip(coordinates, bbox):
        annotations[coord] -= bb.start
        assert (annotations[coord] <= bb.stop).all()

    # tp_ids_objects, tp_ids_annotations =\
    #     _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
    tp_ids_objects, tp_ids_annotations =\
        _naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
    assert len(tp_ids_objects) == len(tp_ids_annotations)

    # Find the false positives: objects that are not part of the matches.
    fp_ids = np.setdiff1d(segmentation_ids, tp_ids_objects)

    # Find the false negatives: annotations that are not part of the matches.
    fn_ids = np.setdiff1d(np.arange(len(annotations)), tp_ids_annotations)

    return {"tp_objects": tp_ids_objects, "tp_annotations": tp_ids_annotations, "fp": fp_ids, "fn": fn_ids}


def compute_scores_for_annotated_slice(
    segmentation: np.typing.ArrayLike,
    annotations: pd.DataFrame,
    matching_tolerance: float = 0.0,
) -> Dict[str, int]:
    """Computes the scores for an annotated validation slice.

    Computes true positives, false positives and false negatives for scoring.

    Args:
        segmentation: The segmentation for this slice. We assume that it is relabeled consecutively.
        annotations: The annotations, marking cell centers.
        matching_tolerance: The maximum distance for matching an annotation to a segmented object.

    Returns:
        A dictionary with keys 'tp', 'fp' and 'fn', mapping to the respective counts.
    """
    result = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance)

    # Determine the TP, FP and FN counts from the matches.
    tp = len(result["tp_objects"])
    fp = len(result["fp"])
    fn = len(result["fn"])
    return {"tp": tp, "fp": fp, "fn": fn}


def for_visualization(segmentation, annotations, matches):
    green_red = ["#00FF00", "#FF0000"]

    # Mark the true positive objects with 1 (green) and the false positive objects with 2 (red).
    seg_vis = np.zeros_like(segmentation)
    tps, fps = matches["tp_objects"], matches["fp"]
    seg_vis[np.isin(segmentation, tps)] = 1
    seg_vis[np.isin(segmentation, fps)] = 2

    seg_props = dict(colormap={1: green_red[0], 2: green_red[1]})

    point_vis = annotations.copy()
    tps = matches["tp_annotations"]
    match_properties = ["tp" if aid in tps else "fn" for aid in range(len(annotations))]
    # The color cycle assigns the first color to the first property etc.
    # So we need to set the first color to red if the first id is a false negative and vice versa.
    color_cycle = green_red[::-1] if match_properties[0] == "fn" else green_red
    point_props = dict(
        properties={
            "id": list(range(len(annotations))),
            "match": match_properties,
        },
        face_color="match",
        face_color_cycle=color_cycle,
        border_width=0.25,
        size=10,
    )

    return seg_vis, point_vis, seg_props, point_props
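
For orientation, here is a minimal usage sketch of the new validation helpers (not part of the commit): it fetches an annotated slice, computes the match-based counts, and derives precision, recall and F1 from them. The annotation csv path, cache path and matching tolerance below are hypothetical examples.

# Hypothetical usage sketch; the paths and the tolerance are made-up examples.
from flamingo_tools.validation import compute_scores_for_annotated_slice, fetch_data_for_evaluation

segmentation, annotations = fetch_data_for_evaluation(
    "annotations/MLR000L_GT_Z0100_annotations.csv",  # hypothetical annotation export
    cache_path="seg_cache.tif",
)
scores = compute_scores_for_annotated_slice(segmentation, annotations, matching_tolerance=2.0)

# Derive precision, recall and F1-score from the returned counts.
tp, fp, fn = scores["tp"], scores["fp"], scores["fn"]
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
print(f"precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f}")

The layer-style keyword arguments returned by for_visualization (colormap, face_color_cycle, border_width, size) look intended for napari labels and points layers, but the viewer wiring is not part of this commit.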

scripts/extract_block.py

Lines changed: 37 additions & 10 deletions
@@ -5,6 +5,7 @@
 import os
 from typing import Optional, List
 
+import imageio.v3 as imageio
 import numpy as np
 import zarr
 
@@ -16,8 +17,10 @@ def main(
     coords: List[int],
     output_dir: str = None,
     input_key: str = "setup0/timepoint0/s0",
+    output_key: Optional[str] = None,
     resolution: float = 0.38,
     roi_halo: List[int] = [128, 128, 64],
+    tif: bool = False,
     s3: Optional[bool] = False,
     s3_credentials: Optional[str] = None,
     s3_bucket_name: Optional[str] = None,
@@ -30,14 +33,17 @@ def main(
         input_path: Input folder in n5 / ome-zarr format.
         coords: Center coordinates of extracted 3D volume.
         output_dir: Output directory for saving output as <basename>_crop.n5. Default: input directory.
+        input_key: Input key for data in input file.
+        output_key: Output key for data in n5 output, or used as suffix for tif output.
         roi_halo: ROI halo of extracted 3D volume.
+        tif: Flag for tif output.
         s3: Flag for considering input_path for S3 bucket.
         s3_bucket_name: S3 bucket name.
         s3_service_endpoint: S3 service endpoint.
         s3_credentials: File path to credentials for S3 bucket.
     """
-
-    coord_string = "-".join([str(c) for c in coords])
+    coords = [int(round(c)) for c in coords]
+    coord_string = "-".join([str(c).zfill(4) for c in coords])
 
     # Dimensions are reversed for viewing in MoBIE: (x y z) -> (z y x)
     coords.reverse()
@@ -46,7 +52,14 @@ def main(
     input_content = list(filter(None, input_path.split("/")))
 
     if s3:
-        basename = input_content[0] + "_" + input_content[-1].split(".")[0]
+        image_name = input_content[-1].split(".")[0]
+        if len(image_name.split("_")) > 1:
+            resized_suffix = "_resized"
+            image_prefix = image_name.split("_")[0]
+        else:
+            resized_suffix = ""
+            image_prefix = image_name
+        basename = input_content[0] + resized_suffix
     else:
         basename = "".join(input_content[-1].split(".")[:-1])
 
@@ -56,7 +69,16 @@ def main(
     if output_dir == "":
         output_dir = input_dir
 
-    output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5")
+    if tif:
+        if output_key is None:
+            output_name = basename + "_crop_" + coord_string + "_" + image_prefix + ".tif"
+        else:
+            output_name = basename + "_" + image_prefix + "_crop_" + coord_string + "_" + output_key + ".tif"
+
+        output_file = os.path.join(output_dir, output_name)
+    else:
+        output_key = "raw" if output_key is None else output_key
+        output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5")
 
     coords = np.array(coords)
     coords = coords / resolution
@@ -75,26 +97,31 @@ def main(
     with zarr.open(input_path, mode="r") as f:
         raw = f[input_key][roi]
 
-    with zarr.open(output_file, mode="w") as f_out:
-        f_out.create_dataset("raw", data=raw, compression="gzip")
+    if tif:
+        imageio.imwrite(output_file, raw, compression="zlib")
+    else:
+        with zarr.open(output_file, mode="w") as f_out:
+            f_out.create_dataset(output_key, data=raw, compression="gzip")
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(
         description="Script to extract region of interest (ROI) block around center coordinate.")
 
-    parser.add_argument('-i', '--input', type=str, help="Input file in n5 / ome-zarr format.")
+    parser.add_argument('-i', '--input', type=str, required=True, help="Input file in n5 / ome-zarr format.")
     parser.add_argument('-o', "--output", type=str, default="", help="Output directory.")
     parser.add_argument('-c', "--coord", type=str, required=True,
                         help="3D coordinate as center of extracted block, json-encoded.")
 
     parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0",
                         help="Input key for data in input file.")
+    parser.add_argument("--output_key", type=str, default=None,
+                        help="Output key for data in output file.")
     parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.")
-
     parser.add_argument("--roi_halo", type=str, default="[128,128,64]",
                         help="ROI halo around center coordinate, json-encoded.")
+    parser.add_argument("--tif", action="store_true", help="Store output as tif file.")
 
     parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
@@ -111,6 +138,6 @@ def main(
     roi_halo = json.loads(args.roi_halo)
 
     main(
-        args.input, coords, args.output, args.input_key, args.resolution, roi_halo,
-        args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
+        args.input, coords, args.output, args.input_key, args.output_key, args.resolution, roi_halo,
+        args.tif, args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
     )
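
For illustration, a hypothetical call of the extended main function, exercising the new output_key parameter; the input path, center coordinate and output directory are made up, and the tif / s3 flags are left at their False defaults:

# Hypothetical example call; input path, center coordinate and output directory are made up.
main(
    "/data/M_LR_000000_L.n5",
    [500, 600, 700],                   # center coordinate in micrometer, MoBIE order (x y z)
    output_dir="/data/crops",
    input_key="setup0/timepoint0/s0",
    output_key="raw",                  # dataset name inside the n5 output
    resolution=0.38,
    roi_halo=[128, 128, 64],
)

This should write /data/crops/M_LR_000000_L_crop_0500-0600-0700.n5 with the crop stored under the key "raw"; from the command line, the corresponding invocation would be something like: python scripts/extract_block.py -i /data/M_LR_000000_L.n5 -c "[500,600,700]" -o /data/crops --output_key raw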
