Skip to content

Commit 39b313f

Browse files
authored
Merge pull request #52 from computational-cell-analytics/postprocess_ihc_synapse
Post-process the IHC segmentation for instances with a high synapse count (25 or more by default). A reference segmentation is used to edit a base segmentation and reduce the number of merged IHC instances. The script searches a segmentation table for IHC instances with a high synapse count. It then changes the IHC segmentation in place: an instance of the base segmentation is replaced with instances from the reference segmentation if at least two reference instances within the same region of interest sufficiently overlap it.
2 parents 6653886 + 623a467 commit 39b313f

File tree

3 files changed

+265
-5
lines changed

3 files changed

+265
-5
lines changed
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
7+
def find_overlapping_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    running_label_id: int,
    min_overlap: float = 0.5,
) -> Tuple[List[dict], int]:
    """Find reference instances that overlap with one instance of the base array.

    Every non-zero label ID of the reference array is checked: the fraction of
    its voxels that lie inside the base instance must reach `min_overlap`.
    Each reference ID fulfilling this criterion is paired with a fresh label ID.
    Neither input array is modified.

    Args:
        arr_base: Label array acting as base. Must have the same shape as `arr_ref`.
        arr_ref: Label array acting as reference. The value 0 is treated as background.
        label_id_base: Value of the instance in the base array.
        running_label_id: First unused label ID; incremented for each match.
        min_overlap: Minimal fraction of a reference instance that must lie inside
            the base instance to consider replacement.

    Returns:
        List of dictionaries with the keys "ref_id" (reference label ID) and
        "new_label" (label ID to use for it in the base array).
        The next unused label ID after all matches were assigned.

    Raises:
        ValueError: If `label_id_base` does not occur in `arr_base`
            (the background value 0 is never accepted as an instance).
    """
    # Boolean mask of the base instance, built without touching the caller's array.
    base_mask = arr_base == label_id_base
    if label_id_base == 0 or not base_mask.any():
        raise ValueError(f"Label id {label_id_base} not found in array. Wrong input?")

    # Total voxel count per reference instance and voxel count inside the base mask,
    # each computed in a single pass instead of copying the array per instance.
    ref_ids, ref_sizes = np.unique(arr_ref, return_counts=True)
    ids_in_base, sizes_in_base = np.unique(arr_ref[base_mask], return_counts=True)
    overlap_sizes = dict(zip(ids_in_base, sizes_in_base))

    edit_labels = []
    for ref_id, ref_size in zip(ref_ids, ref_sizes):
        # Skip background explicitly; slicing off the first unique value would
        # drop a real instance whenever the reference contains no background.
        if ref_id == 0:
            continue
        overlap_ratio = overlap_sizes.get(ref_id, 0) / ref_size
        if overlap_ratio >= min_overlap:
            edit_labels.append({"ref_id": ref_id, "new_label": running_label_id})
            running_label_id += 1

    return edit_labels, running_label_id
53+
54+
55+
def replace_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    edit_labels: List[dict],
) -> np.ndarray:
    """Swap one instance of the base array for several instances of the reference array.

    The base array is edited in place: the instance `label_id_base` is erased
    first, then the footprint of every listed reference instance is painted
    with its new label ID, in list order.

    Args:
        arr_base: Base array. Modified in place.
        arr_ref: Reference array.
        label_id_base: Value of the instance in the base array to be replaced.
        edit_labels: List of dictionaries with the keys "ref_id" (label ID in the
            reference array) and "new_label" (label ID to write into the base array).

    Returns:
        The base array with updated content (the same object as `arr_base`).
    """
    print(f"Replacing {len(edit_labels)} instances")

    # Erase the instance that is about to be replaced.
    arr_base[arr_base == label_id_base] = 0

    for entry in edit_labels:
        ref_id = entry["ref_id"]
        # Footprint of this reference instance; background voxels are never selected.
        footprint = (arr_ref == ref_id) & (arr_ref != 0)
        arr_base[footprint] = entry["new_label"]

    return arr_base
82+
83+
84+
def postprocess_ihc_synapse_crop(
    data_base: np.typing.ArrayLike,
    data_ref: np.typing.ArrayLike,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
) -> np.typing.ArrayLike:
    """Postprocess an IHC segmentation crop based on the synapse count per IHC.

    Instances of the base segmentation whose synapse count reaches
    `synapse_limit` are replaced with instances from the reference
    segmentation, provided that at least two reference instances overlap
    sufficiently with the base instance. The whole crop is compared at once.

    Args:
        data_base: Base array. Edited in place when an instance is replaced.
        data_ref: Reference array.
        table_base: Segmentation table of the base segmentation with the columns
            "label_id" and "syn_per_IHC".
        synapse_limit: Number of synapses per IHC from which on an instance is
            considered for replacement.
        min_overlap: Minimal fraction of overlap between reference and base instances
            to consider replacement.

    Returns:
        Base array with updated content.
    """
    # Instances with suspiciously many synapses (probably merged IHCs).
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]
    flagged_ids = table_edit["label_id"].to_numpy()

    # New labels start after the largest label ID listed in the table.
    running_label_id = int(table_base["label_id"].max() + 1)

    # Flagged instances that actually occur in this crop; background (0) is excluded
    # explicitly rather than by position, so a crop without background loses no label.
    seg_ids_base = np.unique(data_base)
    keep = (seg_ids_base != 0) & np.isin(seg_ids_base, flagged_ids)

    for seg_id_base in seg_ids_base[keep]:
        # The base array is handed over as a copy because the lookup may
        # filter the array it receives.
        edit_labels, running_label_id = find_overlapping_masks(
            data_base.copy(), data_ref, seg_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # A single match would merely relabel the instance, so at least two are required.
        if len(edit_labels) > 1:
            data_base = replace_masks(data_base, data_ref, seg_id_base, edit_labels)
    return data_base
125+
126+
127+
def postprocess_ihc_synapse(
    data_base: np.typing.ArrayLike,
    data_ref: np.typing.ArrayLike,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
    roi_pad: int = 40,
    resolution: float = 0.38,
) -> np.typing.ArrayLike:
    """Postprocess an IHC segmentation based on the synapse count per IHC.

    Instances of the base segmentation whose synapse count reaches
    `synapse_limit` are replaced with instances from the reference
    segmentation, provided that at least two reference instances overlap
    sufficiently with the base instance. Each instance is analysed inside a
    padded region of interest around its bounding box, and the edited region
    is written back into the base array.

    Args:
        data_base: Base array. Edited in place.
        data_ref: Reference array.
        table_base: Segmentation table of the base segmentation with the columns
            "label_id", "syn_per_IHC" and the bounding box columns
            "bb_min_x/y/z" and "bb_max_x/y/z".
        synapse_limit: Number of synapses per IHC from which on an instance is
            considered for replacement.
        min_overlap: Minimal fraction of overlap between reference and base instances
            to consider replacement.
        roi_pad: Padding in pixels added to the bounding box to analyze overlapping
            segmentation masks in a ROI.
        resolution: Resolution of pixels in µm.

    Returns:
        Base array with updated content.
    """
    # Instances with suspiciously many synapses (probably merged IHCs).
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]

    # New labels start after the largest label ID listed in the table.
    running_label_id = int(table_base["label_id"].max() + 1)

    shape = data_base.shape

    for _, row in table_edit.iterrows():
        # Bounding box converted to pixel coordinates by dividing by the resolution.
        # The table is in x, y, z order, the array is indexed in reversed order.
        bb_min = [int(round(row[f"bb_min_{axis}"] / resolution)) for axis in "zyx"]
        bb_max = [int(round(row[f"bb_max_{axis}"] / resolution)) for axis in "zyx"]

        # Clamp the padded box to the array: a negative start would be interpreted
        # as an index counted from the end and select the wrong (or an empty) region.
        roi = tuple(
            slice(max(cmin - roi_pad, 0), min(cmax + roi_pad, dim))
            for cmin, cmax, dim in zip(bb_min, bb_max, shape)
        )

        roi_base = data_base[roi]
        roi_ref = data_ref[roi]
        label_id_base = row["label_id"]

        # The base ROI is handed over as a copy because the lookup may
        # filter the array it receives.
        edit_labels, running_label_id = find_overlapping_masks(
            roi_base.copy(), roi_ref, label_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # A single match would merely relabel the instance, so at least two are required.
        if len(edit_labels) > 1:
            roi_base = replace_masks(roi_base, roi_ref, label_id_base, edit_labels)
            data_base[roi] = roi_base

    return data_base

scripts/measurements/measure_synapses.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,19 @@
1111

1212
def check_project(plot=False, save_ihc_table=False, max_dist=None):
1313
s3 = create_s3_target()
14-
cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R']
15-
synapse_table_name = "synapse_v3_ihc_v4"
16-
ihc_table_name = "IHC_v4"
14+
cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R', 'M_AMD_OTOF1_L']
1715

1816
results = {}
1917
for cochlea in cochleae:
18+
synapse_table_name = "synapse_v3_ihc_v4c"
19+
ihc_table_name = "IHC_v4c"
20+
component_id = [1]
21+
22+
if cochlea == 'M_AMD_OTOF1_L':
23+
synapse_table_name = "synapse_v3_ihc_v4b"
24+
ihc_table_name = "IHC_v4b"
25+
component_id = [3, 11]
26+
2027
content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
2128
info = json.loads(content.read())
2229
sources = info["sources"]
@@ -38,8 +45,7 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None):
3845
ihc_table = pd.read_csv(table_content, sep="\t")
3946

4047
# Keep only the synapses that were matched to a valid IHC.
41-
component_id = 1
42-
valid_ihcs = ihc_table.label_id[ihc_table.component_labels == component_id].values.astype("int")
48+
valid_ihcs = ihc_table.label_id[ihc_table.component_labels.isin(component_id)].values.astype("int")
4349

4450
valid_syn_table = syn_table[syn_table.matched_ihc.isin(valid_ihcs)]
4551
n_synapses = len(valid_syn_table)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""This script post-processes IHC segmentation with too many synapses based on a base segmentation and a reference.
2+
"""
3+
import argparse
4+
5+
import imageio.v3 as imageio
6+
import pandas as pd
7+
from elf.io import open_file
8+
9+
import flamingo_tools.segmentation.ihc_synapse_postprocessing as ihc_synapse_postprocessing
10+
from flamingo_tools.file_utils import read_image_data
11+
12+
13+
def main():
    """Command line entry point for post-processing an IHC segmentation.

    Parses the arguments, loads the base and reference segmentation and the
    synapse-per-IHC table, runs the post-processing (on a crop with `--crop`,
    otherwise per bounding box ROI) and optionally stores the result as TIF.

    Raises:
        ValueError: If `--tif` is given without `--out_path_tif`.
    """
    parser = argparse.ArgumentParser(
        description="Script to postprocess IHC segmentation based on the number of synapses per IHC.")

    parser.add_argument('--base_path', type=str, required=True, help="Base segmentation. WARNING: Will be edited.")
    parser.add_argument('--ref_path', type=str, required=True, help="Reference segmentation.")
    parser.add_argument('--out_path_tif', type=str, default=None, help="Output segmentation for tif output.")

    parser.add_argument('--base_table', type=str, required=True, help="Synapse per IHC table of base segmentation.")

    parser.add_argument("--base_key", type=str, default=None,
                        help="Input key for data in base segmentation.")
    parser.add_argument("--ref_key", type=str, default=None,
                        help="Input key for data in reference segmentation.")

    parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.")
    parser.add_argument("--tif", action="store_true", help="Store output as tif file.")
    parser.add_argument("--crop", action="store_true", help="Process crop of original array.")

    # NOTE(review): the S3 options below are parsed but not used anywhere in this
    # function — confirm whether S3 input is still meant to be wired up.
    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
                        help="Input file containing S3 credentials. "
                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
    parser.add_argument("--s3_bucket_name", type=str, default=None,
                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
    parser.add_argument("--s3_service_endpoint", type=str, default=None,
                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

    args = parser.parse_args()

    # Fail before any data is loaded if the TIF output has no destination.
    if args.tif and args.out_path_tif is None:
        raise ValueError("Specify out_path_tif for saving TIF file.")

    if args.base_key is None:
        data_base = read_image_data(args.base_path, args.base_key)
    else:
        # Opened with mode "a"; presumably so that in-place edits end up
        # in the container file — confirm against elf's open_file.
        data_base = open_file(args.base_path, "a")[args.base_key]
    data_ref = read_image_data(args.ref_path, args.ref_key)

    table_base = pd.read_csv(args.base_table, sep="\t")

    if args.crop:
        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse_crop(
            data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
        )
    else:
        # Use the resolution given on the command line (default 0.38) instead of a
        # hard-coded value, so that -r/--resolution actually takes effect.
        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse(
            data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
            resolution=args.resolution, roi_pad=40,
        )

    if args.tif:
        # The option is named --out_path_tif; args has no attribute out_path.
        imageio.imwrite(args.out_path_tif, output_, compression="zlib")
68+
69+
70+
if __name__ == "__main__":
71+
72+
main()

0 commit comments

Comments
 (0)