Skip to content

Commit d6dbca4

Browse files
committed
Postprocess IHC segmentation based on high synapse count
1 parent 301e018 commit d6dbca4

File tree

2 files changed

+251
-0
lines changed

2 files changed

+251
-0
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
7+
def find_overlapping_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    running_label_id: int,
    min_overlap: float = 0.5,
) -> Tuple[List[dict], int]:
    """Find overlapping masks between a base array and a reference array.

    A label id of the base array is supplied and all unique IDs of the
    reference array are checked for a minimal overlap.
    Returns a list of all label IDs of the reference fulfilling this criteria.

    Args:
        arr_base: 3D array acting as base.
        arr_ref: 3D array acting as reference.
        label_id_base: Value of instance segmentation in base array.
        running_label_id: Unique label id for arrays, which replace the instance in the base array.
        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.

    Returns:
        List of dictionaries containing reference label ID and new label ID in base array.
        The updated label ID for new arrays in base array.
    """
    # Boolean mask of the base instance with too many synapses.
    # A direct comparison avoids mutating the caller's array in place.
    base_mask = arr_base == label_id_base

    edit_labels = []
    # Iterate over all instance ids in the reference segmentation.
    # NOTE(review): np.unique is sorted, so [1:] assumes a background label
    # (0) is present; if the crop contains no background, the smallest real
    # label would be skipped — confirm with the calling convention.
    ref_ids = np.unique(arr_ref)[1:]
    for ref_id in ref_ids:
        ref_mask = arr_ref == ref_id
        # Fraction of the reference instance that lies inside the base instance.
        intersection = np.logical_and(ref_mask, base_mask)
        overlap_ratio = np.sum(intersection) / np.sum(ref_mask)
        if overlap_ratio >= min_overlap:
            edit_labels.append({"ref_id": ref_id, "new_label": running_label_id})
            running_label_id += 1

    return edit_labels, running_label_id
51+
52+
53+
def replace_masks(
    data_base: np.ndarray,
    data_ref: np.ndarray,
    label_id_base: int,
    edit_labels: List[dict],
) -> np.ndarray:
    """Replace mask in base array with multiple masks from reference array.

    Note that ``data_base`` is modified in place and also returned.

    Args:
        data_base: Base array.
        data_ref: Reference array.
        label_id_base: Value of instance segmentation in base array to be replaced.
        edit_labels: List of dictionaries containing reference labels and new label ID.

    Returns:
        Base array with updated content.
    """
    print(f"Replacing {len(edit_labels)} instances")
    # Remove the old (over-merged) instance before writing the replacements.
    data_base[data_base == label_id_base] = 0
    for edit_dic in edit_labels:
        # Direct comparison instead of copy / zero-out / astype(bool):
        # equivalent boolean mask without allocating a full copy of the
        # reference array for every instance.
        data_base[data_ref == edit_dic["ref_id"]] = edit_dic["new_label"]
    return data_base
80+
81+
82+
def postprocess_ihc_synapse_crop(
    data_base_: np.ndarray,
    data_ref_: np.ndarray,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
) -> np.ndarray:
    """Postprocess IHC segmentation based on number of synapse per IHC count.

    Segmentations from a base segmentation are analysed and replaced with
    instances from a reference segmentation, if suitable instances overlap with
    the base segmentation.

    Args:
        data_base_: Base array.
        data_ref_: Reference array.
        table_base: Segmentation table of base segmentation with synapse per IHC counts.
        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.

    Returns:
        Base array with updated content.
    """
    # Filter out problematic IHC segmentation (instances with too many synapses).
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]

    # New labels start above the current maximum to avoid label collisions.
    running_label_id = int(table_base["label_id"].max() + 1)
    # BUGFIX: the previous version shadowed the `min_overlap` parameter with a
    # hard-coded `min_overlap = 0.5` here, silently ignoring the argument.
    # Set for O(1) membership tests instead of scanning a list per id.
    edit_ids = set(table_edit["label_id"])

    data_base = data_base_.copy()

    seg_ids_base = np.unique(data_base)[1:]
    for seg_id_base in seg_ids_base:
        if seg_id_base not in edit_ids:
            continue

        edit_labels, running_label_id = find_overlapping_masks(
            data_base.copy(), data_ref_.copy(), seg_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # Only replace when the instance splits into more than one reference
        # mask; a single overlapping mask would merely relabel the instance.
        if len(edit_labels) > 1:
            data_base = replace_masks(data_base, data_ref_, seg_id_base, edit_labels)
    return data_base
124+
125+
126+
def postprocess_ihc_synapse(
    data_base_: np.ndarray,
    data_ref_: np.ndarray,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
    roi_pad: int = 40,
    resolution: float = 0.38,
) -> np.ndarray:
    """Postprocess IHC segmentation based on number of synapse per IHC count.

    Segmentations from a base segmentation are analysed and replaced with
    instances from a reference segmentation, if suitable instances overlap with
    the base segmentation.

    Args:
        data_base_: Base array.
        data_ref_: Reference array.
        table_base: Segmentation table of base segmentation with synapse per IHC counts.
        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.
        roi_pad: Padding added to bounding box to analyze overlapping segmentation masks in a ROI.
        resolution: Resolution of pixels in µm.

    Returns:
        Base array with updated content.
    """
    # Filter out problematic IHC segmentation (instances with too many synapses).
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]

    # New labels start above the current maximum to avoid label collisions.
    running_label_id = int(table_base["label_id"].max() + 1)

    for _, row in table_edit.iterrows():
        # Convert the bounding box from physical units (µm) to pixel coordinates.
        # NOTE(review): assumes the arrays' axis order matches the (x, y, z)
        # bounding-box columns — confirm against the table convention.
        coords_max = [int(round(row[key] / resolution)) for key in ("bb_max_x", "bb_max_y", "bb_max_z")]
        coords_min = [int(round(row[key] / resolution)) for key in ("bb_min_x", "bb_min_y", "bb_min_z")]
        # BUGFIX: clamp the padded start at 0 — a negative slice start wraps
        # around to the end of the axis and selects the wrong region.
        roi = tuple(
            slice(max(cmin - roi_pad, 0), cmax + roi_pad)
            for cmin, cmax in zip(coords_min, coords_max)
        )

        data_base = data_base_[roi]
        data_ref = data_ref_[roi]
        label_id_base = row["label_id"]

        edit_labels, running_label_id = find_overlapping_masks(
            data_base.copy(), data_ref.copy(), label_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # Only replace when the instance splits into more than one reference mask.
        if len(edit_labels) > 1:
            data_base = replace_masks(data_base, data_ref, label_id_base, edit_labels)
            # Write the edited crop back into the full-size array.
            data_base_[roi] = data_base

    return data_base_
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""This script post-processes IHC segmentation with too many synapses based on a base segmentation and a reference.
2+
"""
3+
import argparse
4+
5+
import imageio.v3 as imageio
6+
import pandas as pd
7+
import zarr
8+
from elf.io import open_file
9+
import flamingo_tools.segmentation.ihc_synapse_postprocessing as ihc_synapse_postprocessing
10+
from flamingo_tools.file_utils import read_image_data
11+
12+
13+
def main():
    """Parse CLI arguments and run IHC synapse-count post-processing."""
    parser = argparse.ArgumentParser(
        description="Script to postprocess IHC segmentation based on the number of synapses per IHC.")

    parser.add_argument('--base_path', type=str, required=True, help="Base segmentation.")
    parser.add_argument('--ref_path', type=str, required=True, help="Reference segmentation.")
    parser.add_argument('--out_path', type=str, required=True, help="Output segmentation.")

    parser.add_argument('--base_table', type=str, required=True, help="Synapse per IHC table of base segmentation.")

    parser.add_argument("--base_key", type=str, default=None,
                        help="Input key for data in base segmentation.")
    parser.add_argument("--ref_key", type=str, default=None,
                        help="Input key for data in reference segmentation.")
    parser.add_argument("--out_key", type=str, default="segmentation",
                        help="Input key for data in output file.")

    parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.")
    parser.add_argument("--tif", action="store_true", help="Store output as tif file.")
    parser.add_argument("--crop", action="store_true", help="Process crop of original array.")

    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
                        help="Input file containing S3 credentials. "
                             "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
    parser.add_argument("--s3_bucket_name", type=str, default=None,
                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
    parser.add_argument("--s3_service_endpoint", type=str, default=None,
                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

    args = parser.parse_args()

    # Without a key, read the full image; with a key, open the container lazily.
    if args.base_key is None:
        data_base_ = read_image_data(args.base_path, args.base_key)
    else:
        data_base_ = open_file(args.base_path, "a")[args.base_key]
    data_ref_ = read_image_data(args.ref_path, args.ref_key)

    # Segmentation table with the synapse-per-IHC counts (tab-separated).
    table_base = pd.read_csv(args.base_table, sep="\t")

    if args.crop:
        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse_crop(
            data_base_, data_ref_, table_base=table_base, synapse_limit=25, min_overlap=0.5,
        )
    else:
        # BUGFIX: the keyword is `roi_pad` (not `roi_buffer`), and the
        # parsed --resolution argument is now actually forwarded.
        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse(
            data_base_, data_ref_, table_base=table_base, synapse_limit=25, min_overlap=0.5,
            resolution=args.resolution, roi_pad=40,
        )

    if args.tif:
        imageio.imwrite(args.out_path, output_, compression="zlib")
    else:
        # BUGFIX: zarr.open returns a Group/Array, which is not a context
        # manager — the previous `with zarr.open(...)` raised at runtime.
        f_out = zarr.open(args.out_path, mode="a")
        f_out.create_dataset(args.out_key, data=output_, compression="gzip")
69+
70+
71+
# Entry point when executed as a script; importing the module has no side effects.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)