
Commit 2ae7bd1

Apply threshold on marker stain
1 parent 88af4e3 commit 2ae7bd1

2 files changed: +225 -53 lines changed

flamingo_tools/segmentation/chreef_utils.py

Lines changed: 50 additions & 53 deletions
@@ -1,5 +1,4 @@
 import os
-import math
 import multiprocessing as mp
 from concurrent import futures
 from typing import List, Tuple
@@ -9,45 +8,45 @@
 from tqdm import tqdm


-def find_annotations(annotation_dir, cochleae=None) -> dict:
+def coord_from_string(center_str):
+    return tuple([int(c) for c in center_str.split("-")])
+
+
+def find_annotations(annotation_dir, cochlea) -> dict:
     """Create dictionary for analysis of ChReef annotations.
     Annotations should have format positive-negative_<cochlea>_crop_<coord>_allNegativeExcluded_thr<thr>.tif

     Args:
         annotation_dir: Directory containing annotations.
     """

-    def extract_center_crop(cochlea, name):
+    def extract_center_string(cochlea, name):
         # Extract center crop coordinate from file name
         crop_suffix = name.split(f"{cochlea}_crop_")[1]
-        coord_str = crop_suffix.split("_")[0]
-        coord = tuple([int(c) for c in coord_str.split("-")])
-        return coord
-
-    def extract_cochlea_str(name):
-        # Extract cochlea str from annotation file name.
-        cochlea_suffix = name.split("negative_")[1]
-        cochlea = cochlea_suffix.split("_crop")[0]
-        return cochlea
-
-    file_names = [entry.name for entry in os.scandir(annotation_dir)]
-    if cochleae is None:
-        cochleae = list(set([extract_cochlea_str(file_name) for file_name in file_names]))
-
-    annotation_dic = {}
-    for cochlea in cochleae:
-        cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name]
-        dic = {"cochlea": cochlea}
-        dic["cochlea_files"] = cochlea_files
-        center_crops = list(set([extract_center_crop(cochlea, name=file_name) for file_name in cochlea_files]))
-        dic["center_coords"] = center_crops
-        dic["center_str"] = [("-").join([str(c).zfill(4) for center_crop in center_crops for c in center_crop])]
-        for center_str in dic["center_str"]:
-            file_neg = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "NegativeExcluded"])][0]
-            file_pos = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "WeakPositive"])][0]
-            dic[center_str] = {"file_neg": file_neg, "file_pos": file_pos}
-        annotation_dic[cochlea] = dic
-    return annotation_dic
+        center_str = crop_suffix.split("_")[0]
+        return center_str
+
+    cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name]
+    dic = {"cochlea": cochlea}
+    dic["cochlea_files"] = cochlea_files
+    center_strings = list(set([extract_center_string(cochlea, name=f) for f in cochlea_files]))
+    center_strings.sort()
+    dic["center_strings"] = center_strings
+    remove_strings = []
+    for center_str in center_strings:
+        files_neg = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "NegativeExcluded"])]
+        files_pos = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "WeakPositive"])]
+        if len(files_neg) != 1 or len(files_pos) != 1:
+            print(f"Skipping crop {center_str} for cochlea {cochlea}. "
+                  f"Missing or multiple annotation files in {annotation_dir}.")
+            remove_strings.append(center_str)
+        else:
+            dic[center_str] = {"file_neg": os.path.join(annotation_dir, files_neg[0]),
+                               "file_pos": os.path.join(annotation_dir, files_pos[0])}
+    for rm_str in remove_strings:
+        dic["center_strings"].remove(rm_str)
+
+    return dic


 def get_roi(coord: tuple, roi_halo: tuple, resolution: float = 0.38) -> Tuple[int]:
@@ -106,7 +105,7 @@ def check_overlap(ref_id):
         return None

     n_threads = min(16, mp.cpu_count())
-    print(f"Parallelizing with {n_threads} Threads.")
+    print(f"Finding overlapping masks with {n_threads} Threads.")
     with futures.ThreadPoolExecutor(n_threads) as pool:
         results = list(tqdm(pool.map(check_overlap, ref_ids), total=len(ref_ids)))

@@ -129,7 +128,7 @@ def find_inbetween_ids(
     # negative annotation == 1, positive annotation == 2
     negexc_negatives = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=1)
     allweak_positives = find_overlapping_masks(arr_allweak, roi_seg, label_id_base=2)
-    inbetween_ids = list(set(negexc_negatives) & set(allweak_positives))
+    inbetween_ids = [int(i) for i in set(negexc_negatives).intersection(set(allweak_positives))]
     return inbetween_ids


@@ -142,26 +141,24 @@ def get_median_intensity(file_negexc, file_allweak, center, data_seg, table):

     roi_seg = data_seg[roi]
     inbetween_ids = find_inbetween_ids(arr_negexc, arr_allweak, roi_seg)
-    intensities = table.loc[table["label_id"].isin(inbetween_ids), table["mean"]]
+    subset = table[table["label_id"].isin(inbetween_ids)]
+    intensities = list(subset["median"])
     return np.median(list(intensities))


-def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure, table_block=None):
-    annotation_dic = find_annotations(annotation_dir, cochleae=[cochlea])
-    for key in annotation_dic.keys():
-        dic = annotation_dic[key]
-        for center_coord, center_str in zip(dic["center_coords"], dic["center_str"]):
-            file_pos = dic[center_str["file_pos"]]
-            file_neg = dic[center_str["file_neg"]]
-            median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg, table_measure)
-
-            annotation_dic[key][center_str]["median_intensity"] = median_intensity
-            if table_block is not None:
-                block_centers = table_block["crop_centers"]
-                for num, block_center in enumerate(block_centers):
-                    dist = math.dist(tuple(block_centers), center_coord)
-                    if dist < 5:
-                        annotation_dic[key][center_str]["block_index"] = num
-                        annotation_dic[key][center_str]["block_center"] = block_center
-
-    return annotation_dic[cochlea]
+def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure):
+    """Find median intensities in blocks and assign them to center positions of cropped block.
+    """
+    annotation_dic = find_annotations(annotation_dir, cochlea)
+    # center_keys = [key for key in annotation_dic["center_strings"] if key in annotation_dic.keys()]
+
+    for center_str in annotation_dic["center_strings"]:
+        center_coord = coord_from_string(center_str)
+        print(f"Getting mean intensities for {center_coord}.")
+        file_pos = annotation_dic[center_str]["file_pos"]
+        file_neg = annotation_dic[center_str]["file_neg"]
+        median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg, table_measure)
+
+        annotation_dic[center_str]["median_intensity"] = median_intensity
+
+    return annotation_dic
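The reworked helpers now operate on a single cochlea and return one dictionary keyed by crop-center strings. A minimal usage sketch under assumed inputs (the directory, cochlea identifier, and crop coordinate below are illustrative, not taken from this commit):

# Minimal sketch of the reworked single-cochlea API; all paths and names here are hypothetical.
from flamingo_tools.segmentation.chreef_utils import find_annotations, localize_median_intensities

annotation_dir = "/path/to/annotator_results"   # assumed folder with positive-negative_*_crop_*.tif annotations
cochlea = "M_LR_000000_L"                       # hypothetical cochlea identifier

annotation_dic = find_annotations(annotation_dir, cochlea)
print(annotation_dic["center_strings"])         # e.g. ["0512-0768-0256"], sorted crop-center strings

# With segmentation data (data_seg) and a measurement table (table_measure) loaded elsewhere:
# annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure)
# annotation_dic["0512-0768-0256"]["median_intensity"] then holds the per-crop threshold.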
Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+import argparse
+import os
+from typing import List, Optional
+
+import pandas as pd
+
+from flamingo_tools.s3_utils import get_s3_path
+from flamingo_tools.file_utils import read_image_data
+from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations
+
+MARKER_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ChReef_PV-GFP/2025-07_PV_GFP_SGN"
+
+
+def get_length_fraction_from_center(table, center_str):
+    """ Get 'length_fraction' parameter for center coordinate by averaging nearby segmentation instances.
+    """
+    center_coord = tuple([int(c) for c in center_str.split("-")])
+    (cx, cy, cz) = center_coord
+    offset = 20
+    subset = table[
+        (cx - offset < table["anchor_x"]) &
+        (table["anchor_x"] < cx + offset) &
+        (cy - offset < table["anchor_y"]) &
+        (table["anchor_y"] < cy + offset) &
+        (cz - offset < table["anchor_z"]) &
+        (table["anchor_z"] < cz + offset)
+    ]
+    length_fraction = list(subset["length_fraction"])
+    length_fraction = float(sum(length_fraction) / len(length_fraction))
+    return length_fraction
+
+
+def apply_nearest_threshold(intensity_dic, table_seg, table_measurement):
+    """Apply threshold to nearest segmentation instances.
+    Crop centers are transformed into the 'length fraction' parameter of the segmentation table.
+    This avoids issues with the spiral shape of the cochlea and maps the assignment onto the Rosenthal's canal.
+    """
+    # assign crop centers to length fraction of Rosenthal's canal
+    lf_intensity = {}
+    for key in intensity_dic.keys():
+        length_fraction = get_length_fraction_from_center(table_seg, key)
+        intensity_dic[key]["length_fraction"] = length_fraction
+        lf_intensity[length_fraction] = {"threshold": intensity_dic[key]["median_intensity"]}
+
+    # get limits for checking marker thresholds
+    lf_intensity = dict(sorted(lf_intensity.items()))
+    lf_fractions = list(lf_intensity.keys())
+    # start of cochlea
+    lf_limits = [0]
+    # half distance between block centers
+    for i in range(len(lf_fractions) - 1):
+        lf_limits.append((lf_fractions[i] + lf_fractions[i+1]) / 2)
+    # end of cochlea
+    lf_limits.append(1)
+
+    marker_labels = [0 for _ in range(len(table_seg))]
+    table_seg.loc[:, "marker_labels"] = marker_labels
+    for num, fraction in enumerate(lf_fractions):
+        subset_seg = table_seg[
+            (table_seg["length_fraction"] > lf_limits[num]) &
+            (table_seg["length_fraction"] < lf_limits[num + 1])
+        ]
+        # assign values based on limits
+        threshold = lf_intensity[fraction]["threshold"]
+        label_ids_seg = subset_seg["label_id"]
+
+        subset_measurement = table_measurement[table_measurement["label_id"].isin(label_ids_seg)]
+        subset_positive = subset_measurement[subset_measurement["median"] >= threshold]
+        subset_negative = subset_measurement[subset_measurement["median"] < threshold]
+        label_ids_pos = list(subset_positive["label_id"])
+        label_ids_neg = list(subset_negative["label_id"])
+
+        table_seg.loc[table_seg["label_id"].isin(label_ids_pos), "marker_labels"] = 1
+        table_seg.loc[table_seg["label_id"].isin(label_ids_neg), "marker_labels"] = 2
+
+    return table_seg
+
+
+def evaluate_marker_annotation(
+    cochleae,
+    output_dir: str,
+    annotation_dirs: Optional[List[str]] = None,
+    seg_name: str = "SGN_v2",
+    marker_name: str = "GFP",
+):
+    """Evaluate marker annotations of a single or multiple annotators.
+    Segmentation instances are assigned a positive (1) or negative label (2)
+    in form of the "marker_label" component of the output segmentation table.
+    The assignment is based on the median intensity supplied by a measurement table.
+    Instances not considered for the assignment are labeled as 0.
+
+    Args:
+        cochleae: List of cochlea
+        output_dir: Output directory for segmentation table with 'marker_label' in format <cochlea>_<marker>_<seg>.tsv
+        annotation_dirs: List of directories containing marker annotations by annotator(s).
+        seg_name: Identifier for segmentation.
+        marker_name: Identifier for marker stain.
+    """
+    input_key = "s0"
+
+    if annotation_dirs is None:
+        if "MARKER_DIR" in globals():
+            marker_dir = MARKER_DIR
+        annotation_dirs = [entry.path for entry in os.scandir(marker_dir)
+                           if os.path.isdir(entry) and "Results" in entry.name]
+
+    for cochlea in cochleae:
+        cochlea_annotations = [a for a in annotation_dirs if len(find_annotations(a, cochlea)["center_strings"]) != 0]
+        print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.")
+
+        # get segmentation data
+        input_path = f"{cochlea}/images/ome-zarr/{seg_name}.ome.zarr"
+        input_path, fs = get_s3_path(input_path)
+        data_seg = read_image_data(input_path, input_key)
+
+        table_seg_path = f"{cochlea}/tables/{seg_name}/default.tsv"
+        table_path_s3, fs = get_s3_path(table_seg_path)
+        with fs.open(table_path_s3, "r") as f:
+            table_seg = pd.read_csv(f, sep="\t")
+
+        seg_string = "-".join(seg_name.split("_"))
+        table_measurement_path = f"{cochlea}/tables/{seg_name}/{marker_name}_{seg_string}_object-measures.tsv"
+        table_path_s3, fs = get_s3_path(table_measurement_path)
+        with fs.open(table_path_s3, "r") as f:
+            table_measurement = pd.read_csv(f, sep="\t")
+
+        # find median intensities by averaging all individual annotations for specific crops
+        annotation_dics = {}
+        annotated_centers = []
+        for annotation_dir in cochlea_annotations:
+
+            annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg, table_measurement)
+            annotated_centers.extend(annotation_dic["center_strings"])
+            annotation_dics[annotation_dir] = annotation_dic
+
+        annotated_centers = list(set(annotated_centers))
+        intensity_dic = {}
+        # loop over all annotated blocks
+        for annotated_center in annotated_centers:
+            intensities = []
+            # loop over annotated block from single user
+            for annotator_key in annotation_dics.keys():
+                if annotated_center not in annotation_dics[annotator_key]["center_strings"]:
+                    continue
+                else:
+                    intensities.append(annotation_dics[annotator_key][annotated_center]["median_intensity"])
+            intensity_dic[annotated_center] = {"median_intensity": float(sum(intensities) / len(intensities))}
+
+        table_seg = apply_nearest_threshold(intensity_dic, table_seg, table_measurement)
+        cochlea_str = "-".join(cochlea.split("_"))
+        out_path = os.path.join(output_dir, f"{cochlea_str}_{marker_name}_{seg_string}.tsv")
+        table_seg.to_csv(out_path, sep="\t", index=False)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Assign each segmentation instance a marker based on annotation thresholds.")
+
+    parser.add_argument('-c', "--cochlea", type=str, nargs="+", required=True,
+                        help="Cochlea(e) to process.")
+    parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.")
+
+    parser.add_argument('-a', '--annotation_dirs', type=str, nargs="+", default=None,
+                        help="Directories containing marker annotations.")
+
+    args = parser.parse_args()
+
+    evaluate_marker_annotation(
+        args.cochlea, args.output, args.annotation_dirs,
+    )
+
+
+if __name__ == "__main__":
+
+    main()
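The half-distance limits in apply_nearest_threshold partition the cochlea's length fraction into one bin per annotated crop, and each bin is thresholded with that crop's median intensity. A small worked example of the binning (fractions are made up), followed by a hypothetical call of the new entry point (the script's module name is not shown on this page, and the cochlea identifier and directories are illustrative only):

# Worked example of the length-fraction binning with made-up fractions.
lf_fractions = [0.2, 0.5, 0.9]                  # sorted length fractions of three annotated crops
lf_limits = [0]                                 # start of the cochlea
for i in range(len(lf_fractions) - 1):
    lf_limits.append((lf_fractions[i] + lf_fractions[i + 1]) / 2)
lf_limits.append(1)                             # end of the cochlea
print(lf_limits)                                # [0, 0.35, 0.7, 1]: each crop thresholds one bin

# Hypothetical call of the new entry point; names and paths are assumptions for illustration.
evaluate_marker_annotation(
    cochleae=["M_LR_000000_L"],
    output_dir="./marker_tables",
    annotation_dirs=["/path/to/AnnotatorA_Results", "/path/to/AnnotatorB_Results"],
)
# Writes <cochlea>_<marker>_<seg>.tsv (e.g. ..._GFP_SGN-v2.tsv with the defaults) containing a
# "marker_labels" column: 1 = positive, 2 = negative, 0 = not assigned to an annotated block.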

0 commit comments