
Commit a3bd929

Evaluate SGN subtype thresholds
1 parent 29cf122 commit a3bd929

2 files changed (+258, -9 lines)

flamingo_tools/segmentation/chreef_utils.py

Lines changed: 14 additions & 9 deletions
@@ -12,7 +12,7 @@ def coord_from_string(center_str):
     return tuple([int(c) for c in center_str.split("-")])
 
 
-def find_annotations(annotation_dir: str, cochlea: str) -> dict:
+def find_annotations(annotation_dir: str, cochlea: str, pattern: str = None) -> dict:
     """Create a dictionary for the analysis of ChReef annotations.
 
     Annotations should have format positive-negative_<cochlea>_crop_<coord>_allNegativeExcluded_thr<thr>.tif
@@ -31,7 +31,11 @@ def extract_center_string(cochlea, name):
         center_str = crop_suffix.split("_")[0]
         return center_str
 
-    cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name]
+    if pattern is not None:
+        cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name
+                         and pattern in entry.name]
+    else:
+        cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name]
     dic = {"cochlea": cochlea}
     dic["cochlea_files"] = cochlea_files
     center_strings = list(set([extract_center_string(cochlea, name=f) for f in cochlea_files]))
@@ -140,7 +144,7 @@ def find_inbetween_ids(
     return inbetween_ids, allweak_positives, negexc_negatives
 
 
-def get_median_intensity(file_negexc, file_allweak, center, data_seg, table):
+def get_median_intensity(file_negexc, file_allweak, center, data_seg, table, column="median"):
     arr_negexc = tifffile.imread(file_negexc)
     arr_allweak = tifffile.imread(file_allweak)
 
@@ -155,31 +159,32 @@ def get_median_intensity(file_negexc, file_allweak, center, data_seg, table):
 
         subset_positive = table[table["label_id"].isin(allweak_positives)]
         subset_negative = table[table["label_id"].isin(negexc_negatives)]
-        lowest_positive = float(subset_positive["median"].min())
-        highest_negative = float(subset_negative["median"].max())
+        lowest_positive = float(subset_positive[column].min())
+        highest_negative = float(subset_negative[column].max())
         if np.isnan(lowest_positive) or np.isnan(highest_negative):
             return None
 
         return np.average([lowest_positive, highest_negative])
 
     subset = table[table["label_id"].isin(inbetween_ids)]
-    intensities = list(subset["median"])
+    intensities = list(subset[column])
 
     return np.median(list(intensities))
 
 
-def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure):
+def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure, column="median", pattern=None):
     """Find median intensities in blocks and assign them to center positions of cropped block.
     """
-    annotation_dic = find_annotations(annotation_dir, cochlea)
+    annotation_dic = find_annotations(annotation_dir, cochlea, pattern=pattern)
     # center_keys = [key for key in annotation_dic["center_strings"] if key in annotation_dic.keys()]
 
     for center_str in annotation_dic["center_strings"]:
         center_coord = coord_from_string(center_str)
         print(f"Getting median intensities for {center_coord}.")
         file_pos = annotation_dic[center_str]["file_pos"]
         file_neg = annotation_dic[center_str]["file_neg"]
-        median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg, table_measure)
+        median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg,
+                                                table_measure, column=column)
 
         if median_intensity is None:
            print(f"No threshold identified for {center_str}.")
Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
import argparse
import json
import os
from typing import List, Optional

import pandas as pd

from flamingo_tools.s3_utils import get_s3_path
from flamingo_tools.file_utils import read_image_data
from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations

MARKER_DIR_SUBTYPE = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes"

# The cochleae for the ChReef analysis.
COCHLEAE = {
    "M_LR_000184_L": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},
    "M_LR_000184_R": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},
    "M_LR_000099_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1", "Lypd1"]},
    "M_LR_000214_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1"]},
}


def get_length_fraction_from_center(table, center_str):
    """Get the 'length_fraction' parameter for a center coordinate by averaging nearby segmentation instances.
    """
    center_coord = tuple([int(c) for c in center_str.split("-")])
    (cx, cy, cz) = center_coord
    offset = 20
    subset = table[
        (cx - offset < table["anchor_x"]) &
        (table["anchor_x"] < cx + offset) &
        (cy - offset < table["anchor_y"]) &
        (table["anchor_y"] < cy + offset) &
        (cz - offset < table["anchor_z"]) &
        (table["anchor_z"] < cz + offset)
    ]
    length_fraction = list(subset["length_fraction"])
    length_fraction = float(sum(length_fraction) / len(length_fraction))
    return length_fraction


def apply_nearest_threshold(intensity_dic, table_seg, table_measurement, column="median", suffix="labels"):
    """Apply thresholds to the nearest segmentation instances.

    Crop centers are transformed into the "length fraction" parameter of the segmentation table.
    This avoids issues with the spiral shape of the cochlea and maps the assignment onto Rosenthal's canal.
    """
    # assign crop centers to a length fraction along Rosenthal's canal
    lf_intensity = {}
    for key in intensity_dic.keys():
        length_fraction = get_length_fraction_from_center(table_seg, key)
        intensity_dic[key]["length_fraction"] = length_fraction
        lf_intensity[length_fraction] = {"threshold": intensity_dic[key]["median_intensity"]}

    # get limits for checking marker thresholds
    lf_intensity = dict(sorted(lf_intensity.items()))
    lf_fractions = list(lf_intensity.keys())
    # start of cochlea
    lf_limits = [0]
    # half distance between block centers
    for i in range(len(lf_fractions) - 1):
        lf_limits.append((lf_fractions[i] + lf_fractions[i + 1]) / 2)
    # end of cochlea
    lf_limits.append(1)

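    # Worked example with illustrative numbers (not taken from the data): two annotated
    # crops at length fractions 0.25 and 0.75 give lf_limits = [0, 0.5, 1], so every SGN
    # with length_fraction below 0.5 is compared against the first crop's threshold and
    # every SGN above 0.5 against the second crop's threshold.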
    marker_labels = [0 for _ in range(len(table_seg))]
    table_seg.loc[:, f"marker_{suffix}"] = marker_labels
    for num, fraction in enumerate(lf_fractions):
        subset_seg = table_seg[
            (table_seg["length_fraction"] > lf_limits[num]) &
            (table_seg["length_fraction"] < lf_limits[num + 1])
        ]
        # assign values based on limits
        threshold = lf_intensity[fraction]["threshold"]
        label_ids_seg = subset_seg["label_id"]

        subset_measurement = table_measurement[table_measurement["label_id"].isin(label_ids_seg)]
        subset_positive = subset_measurement[subset_measurement[column] >= threshold]
        subset_negative = subset_measurement[subset_measurement[column] < threshold]
        label_ids_pos = list(subset_positive["label_id"])
        label_ids_neg = list(subset_negative["label_id"])

        table_seg.loc[table_seg["label_id"].isin(label_ids_pos), f"marker_{suffix}"] = 1
        table_seg.loc[table_seg["label_id"].isin(label_ids_neg), f"marker_{suffix}"] = 2

    return table_seg


def find_thresholds(cochlea_annotations, cochlea, data_seg, table_measurement, column="median", pattern=None):
    """Find the threshold per annotated crop by averaging the median intensities of the individual annotations."""
    annotation_dics = {}
    annotated_centers = []
    for annotation_dir in cochlea_annotations:
        print(f"Localizing threshold with median intensities for {os.path.basename(annotation_dir)}.")
        annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg,
                                                     table_measurement, column=column, pattern=pattern)
        annotated_centers.extend(annotation_dic["center_strings"])
        annotation_dics[annotation_dir] = annotation_dic

    annotated_centers = list(set(annotated_centers))
    intensity_dic = {}
    # loop over all annotated blocks
    for annotated_center in annotated_centers:
        intensities = []
        annotator_success = []
        annotator_failure = []
        annotator_missing = []
        # loop over the annotated block from each annotator
        for annotator_key in annotation_dics.keys():
            if annotated_center not in annotation_dics[annotator_key]["center_strings"]:
                annotator_missing.append(os.path.basename(annotator_key))
                continue
            else:
                median_intensity = annotation_dics[annotator_key][annotated_center]["median_intensity"]
                if median_intensity is None:
                    print(f"No threshold for {os.path.basename(annotator_key)} and crop {annotated_center}.")
                    annotator_failure.append(os.path.basename(annotator_key))
                else:
                    intensities.append(median_intensity)
                    annotator_success.append(os.path.basename(annotator_key))

        if len(intensities) == 0:
            print(f"No viable annotation for cochlea {cochlea} and crop {annotated_center}.")
            median_int_avg = None
        else:
            median_int_avg = float(sum(intensities) / len(intensities))

        intensity_dic[annotated_center] = {
            "median_intensity": median_int_avg,
            "annotation_success": annotator_success,
            "annotation_failure": annotator_failure,
            "annotation_missing": annotator_missing,
        }

    return intensity_dic


def evaluate_marker_annotation(
    cochleae: List[str],
    output_dir: str,
    annotation_dirs: Optional[List[str]] = None,
    seg_name: str = "SGN_v2",
    marker_name: str = "Calb1",
    threshold_save_dir: Optional[str] = None,
    force: bool = False,
) -> None:
    """Evaluate marker annotations of one or several annotators.

    Segmentation instances are assigned a positive (1) or negative (2) label
    in the "marker_<subtype>" column of the output segmentation table.
    The assignment is based on the median intensity supplied by a measurement table.
    Instances not considered for the assignment are labeled as 0.

    Args:
        cochleae: List of cochleae.
        output_dir: Output directory for the segmentation table with marker labels,
            saved in the format <cochlea>_<marker>_<seg>.tsv.
        annotation_dirs: List of directories containing marker annotations by annotator(s).
        seg_name: Identifier for the segmentation.
        marker_name: Identifier for the marker stain.
        threshold_save_dir: Optional directory for saving the thresholds.
        force: Whether to overwrite already existing results.
    """
    input_key = "s0"

    if annotation_dirs is None:
        marker_dir = MARKER_DIR_SUBTYPE
        annotation_dirs = [entry.path for entry in os.scandir(marker_dir)
                           if os.path.isdir(entry) and "Result" in entry.name]

    for cochlea in cochleae:
        data_name = COCHLEAE[cochlea]["seg_data"]
        if "output_seg" in list(COCHLEAE[cochlea].keys()):
            output_seg = COCHLEAE[cochlea]["output_seg"]
        else:
            output_seg = data_name

        seg_string = "-".join(output_seg.split("_"))
        cochlea_str = "-".join(cochlea.split("_"))
        subtypes = COCHLEAE[cochlea]["subtype"]
        subtype_str = "_".join(subtypes)
        out_path = os.path.join(output_dir, f"{cochlea_str}_{subtype_str}_{seg_string}.tsv")
        if os.path.exists(out_path) and not force:
            continue

        # Get the segmentation data and table.
        input_path = f"{cochlea}/images/ome-zarr/{data_name}.ome.zarr"
        input_path, fs = get_s3_path(input_path)
        data_seg = read_image_data(input_path, input_key)

        table_seg_path = f"{cochlea}/tables/{output_seg}/default.tsv"
        table_path_s3, fs = get_s3_path(table_seg_path)
        with fs.open(table_path_s3, "r") as f:
            table_seg = pd.read_csv(f, sep="\t")

        table_measurement_path = f"{cochlea}/tables/{data_name}/subtype_ratio.tsv"

        # iterate through subtypes
        for subtype in subtypes:
            column = f"{subtype}_ratio_PV"
            table_path_s3, fs = get_s3_path(table_measurement_path)
            with fs.open(table_path_s3, "r") as f:
                table_measurement = pd.read_csv(f, sep="\t")

            cochlea_annotations = [a for a in annotation_dirs
                                   if len(find_annotations(a, cochlea, subtype)["center_strings"]) != 0]
            print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.")

            # Find the thresholds from the annotated blocks and save them if specified.
            intensity_dic = find_thresholds(cochlea_annotations, cochlea, data_seg,
                                            table_measurement, column=column, pattern=subtype)
            if threshold_save_dir is not None:
                os.makedirs(threshold_save_dir, exist_ok=True)
                threshold_out_path = os.path.join(threshold_save_dir, f"{cochlea_str}_{subtype}_{seg_string}.json")
                with open(threshold_out_path, "w") as f:
                    json.dump(intensity_dic, f, sort_keys=True, indent=4)

            # Apply the threshold to all SGNs.
            table_seg = apply_nearest_threshold(
                intensity_dic, table_seg, table_measurement, column=column, suffix=subtype,
            )

        # Save the table with positives / negatives for all SGNs.
        os.makedirs(output_dir, exist_ok=True)
        table_seg.to_csv(out_path, sep="\t", index=False)


def main():
    parser = argparse.ArgumentParser(
        description="Assign each segmentation instance a marker label based on annotation thresholds."
    )

    parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=COCHLEAE, help="Cochlea(e) to process.")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.")
    parser.add_argument("-a", "--annotation_dirs", type=str, nargs="+", default=None,
                        help="Directories containing marker annotations.")
    parser.add_argument("--threshold_save_dir", "-t", help="Optional directory for saving the thresholds as JSON.")
    parser.add_argument("-f", "--force", action="store_true", help="Whether to overwrite existing results.")

    args = parser.parse_args()
    evaluate_marker_annotation(
        args.cochlea, args.output, args.annotation_dirs, threshold_save_dir=args.threshold_save_dir, force=args.force,
    )


if __name__ == "__main__":
    main()
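
Usage note: a possible way to drive the new script directly from Python; the output and threshold directories below are placeholders, while the cochlea names are keys of COCHLEAE above. The same run could also be started through main() with the -c, -o, -a, -t and -f flags.

# Hypothetical driver call; the directories are placeholders and the run assumes access
# to the S3 data referenced by get_s3_path.
evaluate_marker_annotation(
    cochleae=["M_LR_000099_L", "M_LR_000214_L"],
    output_dir="./subtype_tables",
    annotation_dirs=None,  # falls back to the "Result*" folders under MARKER_DIR_SUBTYPE
    threshold_save_dir="./subtype_thresholds",
    force=False,
)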
