diff --git a/flamingo_tools/segmentation/ihc_synapse_postprocessing.py b/flamingo_tools/segmentation/ihc_synapse_postprocessing.py
new file mode 100644
index 0000000..4a96a92
--- /dev/null
+++ b/flamingo_tools/segmentation/ihc_synapse_postprocessing.py
@@ -0,0 +1,181 @@
+from typing import List, Tuple
+
+import numpy as np
+import pandas as pd
+
+
+def find_overlapping_masks(
+    arr_base: np.ndarray,
+    arr_ref: np.ndarray,
+    label_id_base: int,
+    running_label_id: int,
+    min_overlap: float = 0.5,
+) -> Tuple[List[dict], int]:
+    """Find overlapping masks between a base array and a reference array.
+    A label id of the base array is supplied and all unique IDs of the
+    reference array are checked for a minimal overlap.
+    Returns a list of all label IDs of the reference fulfilling this criteria.
+
+    Args:
+        arr_base: 3D array acting as base.
+        arr_ref: 3D array acting as reference.
+        label_id_base: Value of instance segmentation in base array.
+        running_label_id: Unique label id for array, which replaces instance in base array.
+        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.
+
+    Returns:
+        List of dictionaries containing reference label ID and new label ID in base array.
+        The updated label ID for new arrays in base array.
+    """
+    # base array containing only segmentation with too many synapses
+    arr_base[arr_base != label_id_base] = 0
+    if np.count_nonzero(arr_base) == 0:
+        raise ValueError(f"Label id {label_id_base} not found in array. Wrong input?")
+    arr_base = arr_base.astype(bool)
+
+    edit_labels = []
+    # iterate through segmentation ids in reference mask
+    ref_ids = np.unique(arr_ref)[1:]
+    for ref_id in ref_ids:
+        arr_ref_instance = arr_ref.copy()
+        arr_ref_instance[arr_ref_instance != ref_id] = 0
+        arr_ref_instance = arr_ref_instance.astype(bool)
+
+        intersection = np.logical_and(arr_ref_instance, arr_base)
+        overlap_ratio = np.sum(intersection) / np.sum(arr_ref_instance)
+        if overlap_ratio >= min_overlap:
+            edit_labels.append({"ref_id": ref_id,
+                                "new_label": running_label_id})
+            running_label_id += 1
+
+    return edit_labels, running_label_id
+
+
+def replace_masks(
+    arr_base: np.ndarray,
+    arr_ref: np.ndarray,
+    label_id_base: int,
+    edit_labels: List[dict],
+) -> np.ndarray:
+    """Replace mask in base array with multiple masks from reference array.
+
+    Args:
+        arr_base: Base array.
+        arr_ref: Reference array.
+        label_id_base: Value of instance segmentation in base array to be replaced.
+        edit_labels: List of dictionaries containing reference labels and new label ID.
+
+    Returns:
+        Base array with updated content.
+    """
+    print(f"Replacing {len(edit_labels)} instances")
+    arr_base[arr_base == label_id_base] = 0
+    for edit_dic in edit_labels:
+        # bool array for new mask
+        data_ref_id = arr_ref.copy()
+        data_ref_id[data_ref_id != edit_dic["ref_id"]] = 0
+        bool_ref = data_ref_id.astype(bool)
+
+        arr_base[bool_ref] = edit_dic["new_label"]
+    return arr_base
+
+
+def postprocess_ihc_synapse_crop(
+    data_base: np.typing.ArrayLike,
+    data_ref: np.typing.ArrayLike,
+    table_base: pd.DataFrame,
+    synapse_limit: int = 25,
+    min_overlap: float = 0.5,
+) -> np.typing.ArrayLike:
+    """Postprocess IHC segmentation based on number of synapse per IHC count.
+    Segmentations from a base segmentation are analysed and replaced with
+    instances from a reference segmentation, if suitable instances overlap with
+    the base segmentation.
+
+    Args:
+        data_base: Base array.
+        data_ref: Reference array.
+        table_base: Segmentation table of base segmentation with synapse per IHC counts.
+        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
+        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.
+
+    Returns:
+        Base array with updated content.
+    """
+    # filter out problematic IHC segmentation
+    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]
+
+    running_label_id = int(table_base["label_id"].max() + 1)
+    edit_labels = []
+
+    seg_ids_base = np.unique(data_base)[1:]
+    for seg_id_base in seg_ids_base:
+        if seg_id_base in list(table_edit["label_id"]):
+
+            edit_labels, running_label_id = find_overlapping_masks(
+                data_base.copy(), data_ref.copy(), seg_id_base,
+                running_label_id, min_overlap=min_overlap,
+            )
+
+            if len(edit_labels) > 1:
+                data_base = replace_masks(data_base, data_ref, seg_id_base, edit_labels)
+    return data_base
+
+
+def postprocess_ihc_synapse(
+    data_base: np.typing.ArrayLike,
+    data_ref: np.typing.ArrayLike,
+    table_base: pd.DataFrame,
+    synapse_limit: int = 25,
+    min_overlap: float = 0.5,
+    roi_pad: int = 40,
+    resolution: float = 0.38,
+) -> np.typing.ArrayLike:
+    """Postprocess IHC segmentation based on number of synapse per IHC count.
+    Segmentations from a base segmentation are analysed and replaced with
+    instances from a reference segmentation, if suitable instances overlap with
+    the base segmentation.
+
+    Args:
+        data_base: Base array.
+        data_ref: Reference array.
+        table_base: Segmentation table of base segmentation with synapse per IHC counts.
+        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
+        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.
+        roi_pad: Padding added to bounding box to analyze overlapping segmentation masks in a ROI.
+        resolution: Resolution of pixels in µm.
+
+    Returns:
+        Base array with updated content.
+    """
+    # filter out problematic IHC segmentation
+    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]
+
+    running_label_id = int(table_base["label_id"].max() + 1)
+
+    for _, row in table_edit.iterrows():
+        # access array in image space (pixels)
+        coords_max = [row["bb_max_x"], row["bb_max_y"], row["bb_max_z"]]
+        coords_max = [int(round(c / resolution)) for c in coords_max]
+        coords_min = [row["bb_min_x"], row["bb_min_y"], row["bb_min_z"]]
+        coords_min = [int(round(c / resolution)) for c in coords_min]
+
+        coords_max.reverse()
+        coords_min.reverse()
+        # clamp ROI start at 0 so that padding near the volume origin cannot wrap around
+        roi = tuple(slice(max(cmin - roi_pad, 0), cmax + roi_pad) for cmax, cmin in zip(coords_max, coords_min))
+
+        roi_base = data_base[roi]
+        roi_ref = data_ref[roi]
+        label_id_base = int(row["label_id"])
+
+        edit_labels, running_label_id = find_overlapping_masks(
+            roi_base.copy(), roi_ref.copy(), label_id_base,
+            running_label_id, min_overlap=min_overlap,
+        )
+
+        if len(edit_labels) > 1:
+            roi_base = replace_masks(roi_base, roi_ref, label_id_base, edit_labels)
+            data_base[roi] = roi_base
+
+    return data_base
diff --git a/scripts/measurements/measure_synapses.py b/scripts/measurements/measure_synapses.py
index d61ab64..f53ec7a 100644
--- a/scripts/measurements/measure_synapses.py
+++ b/scripts/measurements/measure_synapses.py
@@ -11,12 +11,19 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None):
     s3 = create_s3_target()
 
-    cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R']
-    synapse_table_name = "synapse_v3_ihc_v4"
-    ihc_table_name = "IHC_v4"
+    cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R', 'M_AMD_OTOF1_L']
 
     results = {}
     for cochlea in cochleae:
+        synapse_table_name = "synapse_v3_ihc_v4c"
+        ihc_table_name = "IHC_v4c"
+        component_id = [1]
+
+        if cochlea == 'M_AMD_OTOF1_L':
+            synapse_table_name = "synapse_v3_ihc_v4b"
+            ihc_table_name = "IHC_v4b"
+            component_id = [3, 11]
+
         content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r",
                           encoding="utf-8")
         info = json.loads(content.read())
         sources = info["sources"]
@@ -38,8 +45,7 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None):
         ihc_table = pd.read_csv(table_content, sep="\t")
 
         # Keep only the synapses that were matched to a valid IHC.
-        component_id = 1
-        valid_ihcs = ihc_table.label_id[ihc_table.component_labels == component_id].values.astype("int")
+        valid_ihcs = ihc_table.label_id[ihc_table.component_labels.isin(component_id)].values.astype("int")
 
         valid_syn_table = syn_table[syn_table.matched_ihc.isin(valid_ihcs)]
         n_synapses = len(valid_syn_table)
diff --git a/scripts/prediction/postprocess_ihc_synapse.py b/scripts/prediction/postprocess_ihc_synapse.py
new file mode 100644
index 0000000..9f11806
--- /dev/null
+++ b/scripts/prediction/postprocess_ihc_synapse.py
@@ -0,0 +1,72 @@
+"""This script post-processes IHC segmentation with too many synapses based on a base segmentation and a reference.
+"""
+import argparse
+
+import imageio.v3 as imageio
+import pandas as pd
+from elf.io import open_file
+
+import flamingo_tools.segmentation.ihc_synapse_postprocessing as ihc_synapse_postprocessing
+from flamingo_tools.file_utils import read_image_data
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Script to postprocess IHC segmentation based on the number of synapses per IHC.")
+
+    parser.add_argument('--base_path', type=str, required=True, help="Base segmentation. WARNING: Will be edited.")
+    parser.add_argument('--ref_path', type=str, required=True, help="Reference segmentation.")
+    parser.add_argument('--out_path_tif', type=str, default=None, help="Output segmentation for tif output.")
+
+    parser.add_argument('--base_table', type=str, required=True, help="Synapse per IHC table of base segmentation.")
+
+    parser.add_argument("--base_key", type=str, default=None,
+                        help="Input key for data in base segmentation.")
+    parser.add_argument("--ref_key", type=str, default=None,
+                        help="Input key for data in reference segmentation.")
+
+    parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.")
+    parser.add_argument("--tif", action="store_true", help="Store output as tif file.")
+    parser.add_argument("--crop", action="store_true", help="Process crop of original array.")
+
+    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
+    parser.add_argument("--s3_credentials", type=str, default=None,
+                        help="Input file containing S3 credentials. "
+                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
+    parser.add_argument("--s3_bucket_name", type=str, default=None,
+                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
+    parser.add_argument("--s3_service_endpoint", type=str, default=None,
+                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
+
+    args = parser.parse_args()
+
+    if args.tif:
+        if args.out_path_tif is None:
+            raise ValueError("Specify out_path_tif for saving TIF file.")
+
+    if args.base_key is None:
+        data_base = read_image_data(args.base_path, args.base_key)
+    else:
+        data_base = open_file(args.base_path, "a")[args.base_key]
+    data_ref = read_image_data(args.ref_path, args.ref_key)
+
+    with open(args.base_table, "r") as f:
+        table_base = pd.read_csv(f, sep="\t")
+
+    if args.crop:
+        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse_crop(
+            data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
+        )
+    else:
+        output_ = ihc_synapse_postprocessing.postprocess_ihc_synapse(
+            data_base, data_ref, table_base=table_base, synapse_limit=25, min_overlap=0.5,
+            resolution=args.resolution, roi_pad=40,
+        )
+
+    if args.tif:
+        imageio.imwrite(args.out_path_tif, output_, compression="zlib")
+
+
+if __name__ == "__main__":
+
+    main()