Skip to content

Commit d742c76

Browse files
committed
Export table with annotation parameters
1 parent 7e54709 commit d742c76

File tree

2 files changed

+89
-22
lines changed

2 files changed

+89
-22
lines changed

flamingo_tools/segmentation/chreef_utils.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def check_overlap(ref_id):
121121
with futures.ThreadPoolExecutor(n_threads) as pool:
122122
results = list(tqdm(pool.map(check_overlap, ref_ids), total=len(ref_ids)))
123123

124-
matching_ids = {r for r in results if r is not None}
124+
matching_ids = [r for r in results if r is not None]
125125
return matching_ids
126126

127127

@@ -141,39 +141,64 @@ def find_inbetween_ids(
141141
A list of the ids that are in between the respective thresholds.
142142
"""
143143
# negative annotation == 1, positive annotation == 2
144-
negexc_negatives = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=1)
145-
allweak_positives = find_overlapping_masks(arr_allweak, roi_seg, label_id_base=2)
146-
inbetween_ids = [int(i) for i in set(negexc_negatives).intersection(set(allweak_positives))]
147-
return inbetween_ids, allweak_positives, negexc_negatives
144+
negexc_neg = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=1)
145+
allweak_pos = find_overlapping_masks(arr_allweak, roi_seg, label_id_base=2)
148146

147+
negexc_pos = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=2)
148+
allweak_neg = find_overlapping_masks(arr_allweak, roi_seg, label_id_base=1)
149+
inbetween_ids = [int(i) for i in set(negexc_neg).intersection(set(allweak_pos))]
150+
return inbetween_ids, allweak_pos, negexc_neg, allweak_neg, negexc_pos
149151

150-
def get_median_intensity(file_negexc, file_allweak, center, data_seg, table, column="median",
152+
153+
def get_crop_parameters(file_negexc, file_allweak, center, data_seg, table, column="median",
                        resolution=0.38):
    """Collect annotation statistics and an intensity threshold for one annotated crop.

    Both annotation volumes are read with tifffile, the matching region of the
    segmentation is cut out around `center`, and `find_inbetween_ids` assigns the
    segmentation ids in that region to the annotation groups.

    Args:
        file_negexc: Path to the annotation file passed to `find_inbetween_ids` as `arr_negexc`.
        file_allweak: Path to the annotation file passed to `find_inbetween_ids` as `arr_allweak`.
        center: Center of the crop, forwarded to `get_roi`.
        data_seg: Segmentation data; indexed with the ROI returned by `get_roi`.
        table: Table with a "label_id" column and the measurement column `column`.
        column: Name of the table column holding the intensity measurement.
        resolution: Resolution forwarded to `get_roi`.

    Returns:
        A dictionary with the keys
        "inbetween_ids", "allweak_pos", "allweak_neg", "negexc_neg", "negexc_pos"
        (id collections returned by `find_inbetween_ids`),
        "<group>_mean" for each of the four groups (mean of `column` over the
        group's ids, NaN if the group matches no table row), and
        "median_intensity": the median of `column` over the in-between ids, or,
        if there are no in-between ids, the average of the lowest "allweak_pos"
        value and the highest "negexc_neg" value. It is None whenever no finite
        threshold can be derived.
    """
    arr_negexc = tifffile.imread(file_negexc)
    arr_allweak = tifffile.imread(file_allweak)

    # The halo is half of the annotation volume's shape along each axis.
    roi_halo = tuple(extent // 2 for extent in arr_negexc.shape)
    roi = get_roi(center, roi_halo, resolution=resolution)
    roi_seg = data_seg[roi]

    inbetween_ids, allweak_pos, negexc_neg, allweak_neg, negexc_pos = find_inbetween_ids(
        arr_negexc, arr_allweak, roi_seg
    )

    # Insertion order fixes the key order of the returned dictionary.
    id_groups = {
        "allweak_pos": allweak_pos,
        "allweak_neg": allweak_neg,
        "negexc_neg": negexc_neg,
        "negexc_pos": negexc_pos,
    }
    param_dic = {"inbetween_ids": inbetween_ids}
    param_dic.update(id_groups)

    # Filter the table once per group; the selections are reused for the threshold below.
    group_values = {
        name: table.loc[table["label_id"].isin(ids), column]
        for name, ids in id_groups.items()
    }
    for name, values in group_values.items():
        param_dic[f"{name}_mean"] = float(values.mean())

    if len(inbetween_ids) > 0:
        intensities = list(table.loc[table["label_id"].isin(inbetween_ids), column])
        # np.median of an empty selection is NaN (with a warning); skip the call in that case.
        threshold = np.median(intensities) if len(intensities) > 0 else np.nan
    elif len(allweak_pos) == 0 and len(negexc_neg) == 0:
        threshold = np.nan
    else:
        # Without in-between ids, place the threshold midway between the two groups.
        lowest_positive = float(group_values["allweak_pos"].min())
        highest_negative = float(group_values["negexc_neg"].max())
        threshold = np.average([lowest_positive, highest_negative])

    # Callers test for None, so a NaN threshold is reported as "not identified" in every branch.
    param_dic["median_intensity"] = None if np.isnan(threshold) else threshold
    return param_dic
177202

178203

179204
def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure, column="median", pattern=None,
@@ -188,12 +213,14 @@ def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure
188213
print(f"Getting median intensities for {center_coord}.")
189214
file_pos = annotation_dic[center_str]["file_pos"]
190215
file_neg = annotation_dic[center_str]["file_neg"]
191-
median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg,
216+
param_dic = get_crop_parameters(file_neg, file_pos, center_coord, data_seg,
192217
table_measure, column=column, resolution=resolution)
193218

219+
median_intensity = param_dic["median_intensity"]
194220
if median_intensity is None:
195221
print(f"No threshold identified for {center_str}.")
196222

197-
annotation_dic[center_str]["median_intensity"] = median_intensity
223+
for key in param_dic.keys():
224+
annotation_dic[center_str][key] = param_dic[key]
198225

199226
return annotation_dic

scripts/measurements/evaluate_marker_annotations_subtype.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# The cochleae for the CHReef analysis.
1414

1515
# Per-cochlea configuration, keyed by cochlea name.
#   "seg_data":   name of the segmentation used for this cochlea.
#   "subtype":    list of subtype markers to evaluate.
#   "intensity":  intensity mode, either "ratio" or "absolute".
#   "output_seg": optional entry; presumably an alternative output segmentation name (TODO confirm).
COCHLEAE = {
    "M_LR_000099_L": {
        "seg_data": "PV_SGN_v2",
        "subtype": ["Calb1", "Lypd1"],
        "intensity": "ratio",
    },
    "M_AMD_N180_L": {
        "seg_data": "SGN_merged",
        "subtype": ["CR", "Lypd1", "Ntng1"],
        "intensity": "absolute",
    },
    "M_AMD_N180_R": {
        "seg_data": "SGN_merged",
        "subtype": ["CR", "Ntng1"],
        "intensity": "absolute",
    },
    "M_LR_000098_L": {
        "seg_data": "SGN_v2",
        "subtype": ["CR", "Ntng1"],
        "intensity": "ratio",
    },
    "M_LR_000184_L": {
        "seg_data": "SGN_v2",
        "subtype": ["Prph"],
        "output_seg": "SGN_v2b",
        "intensity": "ratio",
    },
    "M_LR_000184_R": {
        "seg_data": "SGN_v2",
        "subtype": ["Prph"],
        "output_seg": "SGN_v2b",
        "intensity": "ratio",
    },
    "M_LR_000214_L": {
        "seg_data": "PV_SGN_v2",
        "subtype": ["Calb1"],
        "intensity": "ratio",
    },
    "M_LR_000260_L": {
        "seg_data": "SGN_v2",
        "subtype": ["Prph", "Tuj1"],
        "intensity": "ratio",
    },
    "M_LR_N152_L": {
        "seg_data": "SGN_v2",
        "subtype": ["CR", "Ntng1"],
        "intensity": "ratio",
    },
}
2526

2627

@@ -135,7 +136,34 @@ def find_thresholds(cochlea_annotations, cochlea, data_seg, table_measurement, c
135136
"annotation_missing": annotator_missing,
136137
}
137138

138-
return intensity_dic
139+
return intensity_dic, annotation_dics
140+
141+
142+
def get_annotation_table(annotation_dics, subtype):
    """Build a table with one row per annotated crop and annotator.

    Args:
        annotation_dics: Mapping from annotation directory path to that
            directory's annotation dictionary. Each dictionary provides the
            list "center_strings" and, for every center string, an entry with
            "median_intensity", the id collections "inbetween_ids",
            "allweak_pos", "allweak_neg", "negexc_pos", "negexc_neg" and the
            corresponding "*_mean" values.
        subtype: Subtype label written to every row.

    Returns:
        A pandas DataFrame with the columns annotator, subtype, center_str,
        median_intensity, the number of ids per group, and the per-group mean
        values. The columns are present even if there are no rows, so an
        exported file always carries a header.
    """
    count_keys = ("inbetween_ids", "allweak_pos", "allweak_neg", "negexc_pos", "negexc_neg")
    mean_keys = ("allweak_pos_mean", "allweak_neg_mean", "negexc_pos_mean", "negexc_neg_mean")
    columns = ["annotator", "subtype", "center_str", "median_intensity", *count_keys, *mean_keys]

    rows = []
    for annotation_dir, annotation_dic in annotation_dics.items():
        # normpath drops a trailing separator, which would make basename return "".
        dir_name = os.path.basename(os.path.normpath(annotation_dir))
        # The annotator is the second "_"-separated token of the directory name;
        # fall back to the full name when the directory has no such token.
        name_parts = dir_name.split("_")
        annotator = name_parts[1] if len(name_parts) > 1 else dir_name

        for center_str in annotation_dic["center_strings"]:
            params = annotation_dic[center_str]
            row = {
                "annotator": annotator,
                "subtype": subtype,
                "center_str": center_str,
                "median_intensity": params["median_intensity"],
            }
            # Only the group sizes are exported, not the ids themselves.
            row.update({key: len(params[key]) for key in count_keys})
            row.update({key: params[key] for key in mean_keys})
            rows.append(row)

    return pd.DataFrame(rows, columns=columns)
139167

140168

141169
def evaluate_marker_annotation(
@@ -181,7 +209,8 @@ def evaluate_marker_annotation(
181209
subtypes = COCHLEAE[cochlea]["subtype"]
182210
subtype_str = "_".join(subtypes)
183211
out_path = os.path.join(output_dir, f"{cochlea_str}_{subtype_str}_{seg_string}.tsv")
184-
if os.path.exists(out_path) and not force:
212+
annot_out = os.path.join(output_dir, f"{cochlea_str}_{subtype_str}_{seg_string}_annotations.tsv")
213+
if os.path.exists(out_path) and os.path.exists(annot_out) and not force:
185214
continue
186215

187216
# Get the segmentation data and table.
@@ -198,6 +227,7 @@ def evaluate_marker_annotation(
198227
intensity_mode = COCHLEAE[cochlea]["intensity"]
199228

200229
# iterate through subtypes
230+
annot_table = None
201231
for subtype in subtypes:
202232
pattern = subtype
203233
if intensity_mode == "ratio":
@@ -218,8 +248,14 @@ def evaluate_marker_annotation(
218248
print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.")
219249

220250
# Find the thresholds from the annotated blocks and save them if specified.
221-
intensity_dic = find_thresholds(cochlea_annotations, cochlea, data_seg,
251+
intensity_dic, annot_dic = find_thresholds(cochlea_annotations, cochlea, data_seg,
222252
table_measurement, column=column, pattern=pattern)
253+
254+
if annot_table is None:
255+
annot_table = get_annotation_table(annot_dic, subtype)
256+
else:
257+
annot_table = pd.concat([annot_table, get_annotation_table(annot_dic, subtype)], ignore_index=True)
258+
223259
if threshold_save_dir is not None:
224260
os.makedirs(threshold_save_dir, exist_ok=True)
225261
threshold_out_path = os.path.join(threshold_save_dir, f"{cochlea_str}_{subtype}_{seg_string}.json")
@@ -241,7 +277,11 @@ def evaluate_marker_annotation(
241277

242278
# Save the table with positives / negatives for all SGNs.
243279
os.makedirs(output_dir, exist_ok=True)
244-
table_seg.to_csv(out_path, sep="\t", index=False)
280+
281+
if not os.path.exists(out_path) or force:
282+
table_seg.to_csv(out_path, sep="\t", index=False)
283+
if not os.path.exists(annot_out) or force:
284+
annot_table.to_csv(annot_out, sep="\t", index=False)
245285

246286

247287
def main():

0 commit comments

Comments
 (0)