|
| 1 | +from typing import List, Tuple |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | + |
def find_overlapping_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    running_label_id: int,
    min_overlap: float = 0.5,
) -> Tuple[List[dict], int]:
    """Find reference instances that overlap with one instance of a base array.

    For every non-zero label ID of the reference array, the fraction of its
    voxels that lies inside the base instance ``label_id_base`` is computed.
    Reference IDs whose fraction reaches ``min_overlap`` are collected and each
    one is assigned a fresh label ID, counted up from ``running_label_id``.

    Neither input array is modified.

    Args:
        arr_base: Array acting as base.
        arr_ref: Array acting as reference, same shape as the base array.
        label_id_base: Value of instance segmentation in base array.
        running_label_id: Next unused label ID; used for the first match.
        min_overlap: Minimal fraction of a reference instance that has to lie
            inside the base instance to consider it for replacement.

    Returns:
        List of dictionaries with the keys "ref_id" (label ID in the reference
        array) and "new_label" (label ID to use in the base array), ordered by
        ascending reference ID.
        The updated running label ID, i.e. the next unused label ID.

    Raises:
        ValueError: If ``label_id_base`` is 0 or does not occur in ``arr_base``.
    """
    # Build the mask by comparison instead of zeroing the other labels in place,
    # so the caller's array stays untouched.
    base_mask = np.asarray(arr_base) == label_id_base
    # 0 is the background value and never counts as an instance.
    if label_id_base == 0 or not base_mask.any():
        raise ValueError(f"Label id {label_id_base} not found in array. Wrong input?")

    arr_ref = np.asarray(arr_ref)
    # Size of every reference instance and size of its part inside the base
    # instance, obtained in two passes instead of one array copy per ID.
    ref_ids, ref_sizes = np.unique(arr_ref, return_counts=True)
    inside_ids, inside_sizes = np.unique(arr_ref[base_mask], return_counts=True)
    overlap_sizes = dict(zip(inside_ids, inside_sizes))

    edit_labels = []
    for ref_id, ref_size in zip(ref_ids, ref_sizes):
        # Skip background explicitly; do not rely on 0 being the first unique
        # value, because a crop may contain no background at all.
        if ref_id == 0:
            continue
        overlap_ratio = overlap_sizes.get(ref_id, 0) / ref_size
        if overlap_ratio >= min_overlap:
            edit_labels.append({"ref_id": ref_id, "new_label": running_label_id})
            running_label_id += 1

    return edit_labels, running_label_id
| 53 | + |
| 54 | + |
def replace_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    edit_labels: List[dict],
) -> np.ndarray:
    """Replace one instance of the base array with instances of the reference array.

    The base array is modified in place and also returned.

    Args:
        arr_base: Base array.
        arr_ref: Reference array.
        label_id_base: Value of instance segmentation in base array to be replaced.
        edit_labels: List of dictionaries with the keys "ref_id" (label ID in
            the reference array) and "new_label" (label ID written to the base array).

    Returns:
        Base array with updated content.
    """
    print(f"Replacing {len(edit_labels)} instances")

    # remove the instance that is being replaced
    arr_base[arr_base == label_id_base] = 0

    for entry in edit_labels:
        # keep only the current reference instance, everything else becomes 0
        single_instance = np.where(arr_ref != entry["ref_id"], 0, arr_ref)
        # paint the footprint of that instance into the base array
        footprint = single_instance.astype(bool)
        arr_base[footprint] = entry["new_label"]

    return arr_base
| 82 | + |
| 83 | + |
def postprocess_ihc_synapse_crop(
    data_base: np.typing.ArrayLike,
    data_ref: np.typing.ArrayLike,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
) -> np.typing.ArrayLike:
    """Postprocess IHC segmentation based on number of synapse per IHC count.

    Segmentations from a base segmentation are analysed and replaced with
    instances from a reference segmentation, if suitable instances overlap with
    the base segmentation. An instance is only replaced if more than one
    reference instance fulfils the overlap criterion.

    The whole arrays are processed for every flagged instance. The base array
    is modified in place and also returned.

    Args:
        data_base: Base array.
        data_ref: Reference array.
        table_base: Segmentation table of base segmentation with the columns
            "label_id" and "syn_per_IHC" (synapse per IHC count).
        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.

    Returns:
        Base array with updated content.
    """
    # filter out problematic IHC segmentation
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]
    if table_edit.empty:
        # nothing to replace; also avoids int(nan) for an empty table
        return data_base

    # built once; set membership is O(1) for every segmentation id
    flagged_ids = set(table_edit["label_id"])
    running_label_id = int(table_base["label_id"].max() + 1)

    # drop background explicitly instead of assuming it is the first unique value
    seg_ids_base = np.unique(data_base)
    seg_ids_base = seg_ids_base[seg_ids_base != 0]

    for seg_id_base in seg_ids_base:
        if seg_id_base not in flagged_ids:
            continue

        # copies protect the arrays in case the helper writes to its inputs
        edit_labels, running_label_id = find_overlapping_masks(
            data_base.copy(), data_ref.copy(), seg_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # a single match would only swap one mask for another; replace only
        # when the base instance maps to several reference instances
        if len(edit_labels) > 1:
            data_base = replace_masks(data_base, data_ref, seg_id_base, edit_labels)

    return data_base
| 125 | + |
| 126 | + |
def postprocess_ihc_synapse(
    data_base: np.typing.ArrayLike,
    data_ref: np.typing.ArrayLike,
    table_base: pd.DataFrame,
    synapse_limit: int = 25,
    min_overlap: float = 0.5,
    roi_pad: int = 40,
    resolution: float = 0.38,
) -> np.typing.ArrayLike:
    """Postprocess IHC segmentation based on number of synapse per IHC count.

    Segmentations from a base segmentation are analysed and replaced with
    instances from a reference segmentation, if suitable instances overlap with
    the base segmentation. An instance is only replaced if more than one
    reference instance fulfils the overlap criterion.

    Only a padded bounding box around each flagged instance is processed. The
    base array is modified in place and also returned.

    Args:
        data_base: Base array.
        data_ref: Reference array.
        table_base: Segmentation table of base segmentation with the columns
            "label_id", "syn_per_IHC" (synapse per IHC count) and the bounding
            box columns "bb_min_x/y/z" and "bb_max_x/y/z".
        synapse_limit: Limit of synapses per IHC to consider replacement of base segmentation.
        min_overlap: Minimal fraction of overlap between ref and base instances to consider replacement.
        roi_pad: Padding in pixels added to bounding box to analyze overlapping segmentation masks in a ROI.
        resolution: Resolution of pixels in µm.

    Returns:
        Base array with updated content.
    """
    # filter out problematic IHC segmentation
    table_edit = table_base[table_base["syn_per_IHC"] >= synapse_limit]
    if table_edit.empty:
        # nothing to replace; also avoids int(nan) for an empty table
        return data_base

    running_label_id = int(table_base["label_id"].max() + 1)

    for _, row in table_edit.iterrows():
        # The bounding box is divided by the resolution to get pixel
        # coordinates; the arrays are indexed in (z, y, x) order.
        roi = []
        for axis in ("z", "y", "x"):
            start = int(round(row[f"bb_min_{axis}"] / resolution)) - roi_pad
            stop = int(round(row[f"bb_max_{axis}"] / resolution)) + roi_pad
            # A negative bound would be interpreted relative to the end of the
            # axis, so the padded box is clamped at the array border.
            roi.append(slice(max(start, 0), max(stop, 0)))
        roi = tuple(roi)

        roi_base = data_base[roi]
        roi_ref = data_ref[roi]
        # iterrows may upcast the row to float; the label is an integer id
        label_id_base = int(row["label_id"])

        # copies protect the arrays in case the helper writes to its inputs
        edit_labels, running_label_id = find_overlapping_masks(
            roi_base.copy(), roi_ref.copy(), label_id_base,
            running_label_id, min_overlap=min_overlap,
        )

        # a single match would only swap one mask for another; replace only
        # when the base instance maps to several reference instances
        if len(edit_labels) > 1:
            roi_base = replace_masks(roi_base, roi_ref, label_id_base, edit_labels)
            # explicit write-back in case the ROI is a copy rather than a view
            data_base[roi] = roi_base

    return data_base
0 commit comments