182 changes: 182 additions & 0 deletions flamingo_tools/segmentation/ihc_synapse_postprocessing.py
@@ -0,0 +1,182 @@
from typing import List, Tuple

import numpy as np
import pandas as pd


def find_overlapping_masks(
arr_base: np.ndarray,
arr_ref: np.ndarray,
label_id_base: int,
running_label_id: int,
min_overlap: float = 0.5,
) -> Tuple[List[dict], int]:
"""Find overlapping masks between a base array and a reference array.
A label id of the base array is supplied and all unique IDs of the
reference array are checked for a minimal overlap.
Returns a list of all label IDs of the reference fulfilling this criteria.
Args:
arr_base: 3D array acting as base.
arr_ref: 3D array acting as reference.
label_id_base: Value of instance segmentation in base array.
running_label_id: Unique label id for array, which replaces instance in base array.
min_overlap: Minimal fraction of overlap between ref and base isntances to consider replacement.
Returns:
List of dictionaries containing reference label ID and new label ID in base array.
The updated label ID for new arrays in base array.
"""
# keep only the base instance to be replaced (the one with too many synapses)
arr_base[arr_base != label_id_base] = 0
if np.count_nonzero(arr_base) == 0:
raise ValueError(f"Label id {label_id_base} not found in array. Wrong input?")
arr_base = arr_base.astype(bool)

edit_labels = []
# iterate through segmentation ids in reference mask
ref_ids = np.unique(arr_ref)[1:]
for ref_id in ref_ids:
arr_ref_instance = arr_ref.copy()
arr_ref_instance[arr_ref_instance != ref_id] = 0
arr_ref_instance = arr_ref_instance.astype(bool)

intersection = np.logical_and(arr_ref_instance, arr_base)
overlap_ratio = np.sum(intersection) / np.sum(arr_ref_instance)
if overlap_ratio >= min_overlap:
edit_labels.append({"ref_id": ref_id,
"new_label": running_label_id})
running_label_id += 1

return edit_labels, running_label_id
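# Illustrative sketch (editor's example, not part of the PR): how `find_overlapping_masks`
# behaves on a tiny toy volume. One oversized base instance (id 7) is covered by two
# reference instances (ids 2 and 3); both overlap fully with the base mask, so both are
# reported and mapped to fresh labels starting at `running_label_id`.
#
#     base = np.zeros((1, 4, 4), dtype=np.uint16)
#     base[0, :, :2] = 7                      # one oversized base instance
#     ref = np.zeros_like(base)
#     ref[0, :2, :2] = 2                      # two reference instances inside the base mask
#     ref[0, 2:, :2] = 3
#     edits, next_id = find_overlapping_masks(base.copy(), ref, label_id_base=7,
#                                             running_label_id=100, min_overlap=0.5)
#     # edits == [{"ref_id": 2, "new_label": 100}, {"ref_id": 3, "new_label": 101}]
#     # next_id == 102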


def replace_masks(
arr_base: np.ndarray,
arr_ref: np.ndarray,
label_id_base: int,
edit_labels: List[dict],
) -> np.ndarray:
"""Replace mask in base array with multiple masks from reference array.
Args:
data_base: Base array.
data_ref: Reference array.
label_id_base: Value of instance segmentation in base array to be replaced.
edit_labels: List of dictionaries containing reference labels and new label ID.
Returns:
Base array with updated content.
"""
print(f"Replacing {len(edit_labels)} instances")
arr_base[arr_base == label_id_base] = 0
for edit_dic in edit_labels:
# bool array for new mask
data_ref_id = arr_ref.copy()
data_ref_id[data_ref_id != edit_dic["ref_id"]] = 0
bool_ref = data_ref_id.astype(bool)

arr_base[bool_ref] = edit_dic["new_label"]
return arr_base
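# Illustrative sketch (editor's example, not part of the PR): applying the edits from the
# sketch above. `replace_masks` clears the oversized base instance (id 7) and paints the two
# overlapping reference masks back in with their new labels.
#
#     new_base = replace_masks(base.copy(), ref, label_id_base=7, edit_labels=edits)
#     # np.unique(new_base) -> array([0, 100, 101])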


def postprocess_ihc_synapse_crop(
data_base: np.typing.ArrayLike,
data_ref: np.typing.ArrayLike,
table_base: pd.DataFrame,
synapse_limit: int = 25,
min_overlap: float = 0.5,
) -> np.typing.ArrayLike:
"""Postprocess IHC segmentation based on number of synapse per IHC count.
Segmentations from a base segmentation are analysed and replaced with
instances from a reference segmentation, if suitable instances overlap with
the base segmentation.
Args:
data_base_: Base array.
data_ref_: Reference array.
table_base: Segmentation table of base segmentation with synapse per IHC counts.
synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
min_overlap: Minimal fraction of overlap between ref and base isntances to consider replacement.
Returns:
Base array with updated content.
"""
# select IHC instances with too many synapses
table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]

running_label_id = int(table_base["label_id"].max() + 1)
edit_labels = []

seg_ids_base = np.unique(data_base)[1:]
for seg_id_base in seg_ids_base:
if seg_id_base in list(table_edit["label_id"]):

edit_labels, running_label_id = find_overlapping_masks(
data_base.copy(), data_ref.copy(), seg_id_base,
running_label_id, min_overlap=min_overlap,
)

if len(edit_labels) > 1:
data_base = replace_masks(data_base, data_ref, seg_id_base, edit_labels)
return data_base
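# Illustrative sketch (editor's example, not part of the PR): calling the crop variant
# directly on the in-memory toy arrays from above. The table only needs the "label_id" and
# "syn_per_IHC" columns here; any base instance at or above `synapse_limit` is checked
# against the reference and replaced if more than one overlapping instance is found.
#
#     table = pd.DataFrame({"label_id": [7], "syn_per_IHC": [30]})
#     cleaned = postprocess_ihc_synapse_crop(base, ref, table_base=table, synapse_limit=25)
#     # the single base instance 7 is split into two new instances (labels 8 and 9)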


def postprocess_ihc_synapse(
data_base: np.typing.ArrayLike,
data_ref: np.typing.ArrayLike,
table_base: pd.DataFrame,
synapse_limit: int = 25,
min_overlap: float = 0.5,
roi_pad: int = 40,
resolution: float = 0.38,
) -> np.typing.ArrayLike:
"""Postprocess IHC segmentation based on number of synapse per IHC count.
Segmentations from a base segmentation are analysed and replaced with
instances from a reference segmentation, if suitable instances overlap with
the base segmentation.
Args:
data_base: Base array.
data_ref: Reference array.
table_base: Segmentation table of base segmentation with synapse per IHC counts.
synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
min_overlap: Minimal fraction of overlap between ref and base isntances to consider replacement.
roi_pad: Padding added to bounding box to analyze overlapping segmentation masks in a ROI.
resolution: Resolution of pixels in µm.
Returns:
Base array with updated content.
"""
# select IHC instances with too many synapses
table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]

running_label_id = int(table_base["label_id"].max() + 1)

for _, row in table_edit.iterrows():
# access array in image space (pixels)
coords_max = [row["bb_max_x"], row["bb_max_y"], row["bb_max_z"]]
coords_max = [int(round(c / resolution)) for c in coords_max]
coords_min = [row["bb_min_x"], row["bb_min_y"], row["bb_min_z"]]
coords_min = [int(round(c / resolution)) for c in coords_min]

coords_max.reverse()
coords_min.reverse()
roi = tuple(slice(cmin - roi_pad, cmax + roi_pad) for cmax, cmin in zip(coords_max, coords_min))

roi_base = data_base[roi]
roi_ref = data_ref[roi]
label_id_base = row["label_id"]

edit_labels, running_label_id = find_overlapping_masks(
roi_base.copy(), roi_ref.copy(), label_id_base,
running_label_id, min_overlap=min_overlap,
)

if len(edit_labels) > 1:
roi_base = replace_masks(roi_base, roi_ref, label_id_base, edit_labels)
data_base[roi] = roi_base

return data_base
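# Illustrative sketch (editor's example, not part of the PR): how a table row's bounding box
# (given in µm, in x/y/z order) is turned into a padded zyx pixel ROI with resolution=0.38
# and roi_pad=40. The coordinate values below are made up for the example.
#
#     row = {"bb_min_x": 38.0, "bb_min_y": 76.0, "bb_min_z": 114.0,
#            "bb_max_x": 76.0, "bb_max_y": 114.0, "bb_max_z": 152.0}
#     mins = [int(round(row[f"bb_min_{a}"] / 0.38)) for a in "xyz"]   # [100, 200, 300]
#     maxs = [int(round(row[f"bb_max_{a}"] / 0.38)) for a in "xyz"]   # [200, 300, 400]
#     roi = tuple(slice(lo - 40, hi + 40) for lo, hi in zip(mins[::-1], maxs[::-1]))
#     # roi == (slice(260, 440), slice(160, 340), slice(60, 240))   # z, y, x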
6 changes: 3 additions & 3 deletions scripts/measurements/measure_synapses.py
@@ -11,9 +11,9 @@

def check_project(plot=False, save_ihc_table=False, max_dist=None):
s3 = create_s3_target()
- cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R']
- synapse_table_name = "synapse_v3_ihc_v4"
- ihc_table_name = "IHC_v4"
+ cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R', 'M_AMD_OTOF1_L']
+ synapse_table_name = "synapse_v3_ihc_v4c"
+ ihc_table_name = "IHC_v4c"

results = {}
for cochlea in cochleae:
70 changes: 70 additions & 0 deletions scripts/prediction/postprocess_ihc_synapse.py
@@ -0,0 +1,70 @@
"""This script post-processes IHC segmentation with too many synapses based on a base segmentation and a reference.
"""
import argparse

import imageio.v3 as imageio
import pandas as pd
from elf.io import open_file

import flamingo_tools.segmentation.ihc_synapse_postprocessing as ihc_synapse_postprocessing
from flamingo_tools.file_utils import read_image_data


def main():
parser = argparse.ArgumentParser(
description="Script to postprocess IHC segmentation based on the number of synapses per IHC.")

parser.add_argument('--base_path', type=str, required=True, help="Base segmentation.")
parser.add_argument('--ref_path', type=str, required=True, help="Reference segmentation.")
parser.add_argument('--out_path', type=str, required=True, help="Output segmentation.")

parser.add_argument('--base_table', type=str, required=True, help="Synapse per IHC table of base segmentation.")

parser.add_argument("--base_key", type=str, default=None,
help="Input key for data in base segmentation.")
parser.add_argument("--ref_key", type=str, default=None,
help="Input key for data in reference segmentation.")
parser.add_argument("--out_key", type=str, default="segmentation",
help="Input key for data in output file.")

parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.")
parser.add_argument("--tif", action="store_true", help="Store output as tif file.")
parser.add_argument("--crop", action="store_true", help="Process crop of original array.")

parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
parser.add_argument("--s3_credentials", type=str, default=None,
help="Input file containing S3 credentials. "
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
parser.add_argument("--s3_bucket_name", type=str, default=None,
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
parser.add_argument("--s3_service_endpoint", type=str, default=None,
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

args = parser.parse_args()

if args.base_key is None:
data_base = read_image_data(args.base_path, args.base_key)
else:
data_base = open_file(args.base_path, "a")[args.base_key]
data_ref = read_image_data(args.ref_path, args.ref_key)

with open(args.base_table, "r") as f:
table_base = pd.read_csv(f, sep="\t")

if args.crop:
output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse_crop(
data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
)
else:
output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse(
data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
resolution=args.resolution, roi_pad=40,
)

if args.tif:
imageio.imwrite(args.out_path, output_, compression="zlib")
else:
# editor's sketch (assumption, not in the original PR): write the result to a container
# file under --out_key; assumes the handle returned by elf.io's open_file supports an
# h5py-style create_dataset.
with open_file(args.out_path, "a") as f:
f.create_dataset(args.out_key, data=output_, compression="gzip")


if __name__ == "__main__":
main()
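# Illustrative sketch (editor's example, not part of the PR): a possible invocation, with
# hypothetical file names, assuming an n5/zarr base segmentation with an internal key and a
# tif reference segmentation:
#
#     python scripts/prediction/postprocess_ihc_synapse.py \
#         --base_path IHC_base.n5 --base_key segmentation \
#         --ref_path IHC_reference.tif \
#         --base_table ihc_synapse_counts.tsv \
#         --out_path IHC_postprocessed.tif --tif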