| 
 | 1 | +import argparse  | 
 | 2 | +import json  | 
 | 3 | +import os  | 
 | 4 | +from typing import List, Optional  | 
 | 5 | + | 
 | 6 | +import pandas as pd  | 
 | 7 | + | 
 | 8 | +from flamingo_tools.s3_utils import get_s3_path  | 
 | 9 | +from flamingo_tools.file_utils import read_image_data  | 
 | 10 | +from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations  | 
 | 11 | + | 
 | 12 | +MARKER_DIR_SUBTYPE = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/SGN_subtypes"  | 
 | 13 | +# The cochlea for the CHReef analysis.  | 
 | 14 | + | 
 | 15 | +COCHLEAE = {  | 
 | 16 | +    "M_LR_000184_L": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},  | 
 | 17 | +    "M_LR_000184_R": {"seg_data": "SGN_v2", "subtype": ["Prph"], "output_seg": "SGN_v2b"},  | 
 | 18 | +    "M_LR_000099_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1", "Lypd1"]},  | 
 | 19 | +    "M_LR_000214_L": {"seg_data": "PV_SGN_v2", "subtype": ["Calb1"]},  | 
 | 20 | +}  | 
 | 21 | + | 
 | 22 | + | 
 | 23 | +def get_length_fraction_from_center(table, center_str):  | 
 | 24 | +    """Get 'length_fraction' parameter for center coordinate by averaging nearby segmentation instances.  | 
 | 25 | +    """  | 
 | 26 | +    center_coord = tuple([int(c) for c in center_str.split("-")])  | 
 | 27 | +    (cx, cy, cz) = center_coord  | 
 | 28 | +    offset = 20  | 
 | 29 | +    subset = table[  | 
 | 30 | +        (cx - offset < table["anchor_x"]) &  | 
 | 31 | +        (table["anchor_x"] < cx + offset) &  | 
 | 32 | +        (cy - offset < table["anchor_y"]) &  | 
 | 33 | +        (table["anchor_y"] < cy + offset) &  | 
 | 34 | +        (cz - offset < table["anchor_z"]) &  | 
 | 35 | +        (table["anchor_z"] < cz + offset)  | 
 | 36 | +    ]  | 
 | 37 | +    length_fraction = list(subset["length_fraction"])  | 
 | 38 | +    length_fraction = float(sum(length_fraction) / len(length_fraction))  | 
 | 39 | +    return length_fraction  | 
 | 40 | + | 
 | 41 | + | 
 | 42 | +def apply_nearest_threshold(intensity_dic, table_seg, table_measurement, column="median", suffix="labels"):  | 
 | 43 | +    """Apply threshold to nearest segmentation instances.  | 
 | 44 | +    Crop centers are transformed into the "length fraction" parameter of the segmentation table.  | 
 | 45 | +    This avoids issues with the spiral shape of the cochlea and maps the assignment onto the Rosenthal"s canal.  | 
 | 46 | +    """  | 
 | 47 | +    # assign crop centers to length fraction of Rosenthal"s canal  | 
 | 48 | +    lf_intensity = {}  | 
 | 49 | +    for key in intensity_dic.keys():  | 
 | 50 | +        length_fraction = get_length_fraction_from_center(table_seg, key)  | 
 | 51 | +        intensity_dic[key]["length_fraction"] = length_fraction  | 
 | 52 | +        lf_intensity[length_fraction] = {"threshold": intensity_dic[key]["median_intensity"]}  | 
 | 53 | + | 
 | 54 | +    # get limits for checking marker thresholds  | 
 | 55 | +    lf_intensity = dict(sorted(lf_intensity.items()))  | 
 | 56 | +    lf_fractions = list(lf_intensity.keys())  | 
 | 57 | +    # start of cochlea  | 
 | 58 | +    lf_limits = [0]  | 
 | 59 | +    # half distance between block centers  | 
 | 60 | +    for i in range(len(lf_fractions) - 1):  | 
 | 61 | +        lf_limits.append((lf_fractions[i] + lf_fractions[i+1]) / 2)  | 
 | 62 | +    # end of cochlea  | 
 | 63 | +    lf_limits.append(1)  | 
 | 64 | + | 
 | 65 | +    marker_labels = [0 for _ in range(len(table_seg))]  | 
 | 66 | +    table_seg.loc[:, f"marker_{suffix}"] = marker_labels  | 
 | 67 | +    for num, fraction in enumerate(lf_fractions):  | 
 | 68 | +        subset_seg = table_seg[  | 
 | 69 | +            (table_seg["length_fraction"] > lf_limits[num]) &  | 
 | 70 | +            (table_seg["length_fraction"] < lf_limits[num + 1])  | 
 | 71 | +        ]  | 
 | 72 | +        # assign values based on limits  | 
 | 73 | +        threshold = lf_intensity[fraction]["threshold"]  | 
 | 74 | +        label_ids_seg = subset_seg["label_id"]  | 
 | 75 | + | 
 | 76 | +        subset_measurement = table_measurement[table_measurement["label_id"].isin(label_ids_seg)]  | 
 | 77 | +        subset_positive = subset_measurement[subset_measurement[column] >= threshold]  | 
 | 78 | +        subset_negative = subset_measurement[subset_measurement[column] < threshold]  | 
 | 79 | +        label_ids_pos = list(subset_positive["label_id"])  | 
 | 80 | +        label_ids_neg = list(subset_negative["label_id"])  | 
 | 81 | + | 
 | 82 | +        table_seg.loc[table_seg["label_id"].isin(label_ids_pos), f"marker_{suffix}"] = 1  | 
 | 83 | +        table_seg.loc[table_seg["label_id"].isin(label_ids_neg), f"marker_{suffix}"] = 2  | 
 | 84 | + | 
 | 85 | +    return table_seg  | 
 | 86 | + | 
 | 87 | + | 
 | 88 | +def find_thresholds(cochlea_annotations, cochlea, data_seg, table_measurement, column="median", pattern=None):  | 
 | 89 | +    # Find the median intensities by averaging the individual annotations for specific crops  | 
 | 90 | +    annotation_dics = {}  | 
 | 91 | +    annotated_centers = []  | 
 | 92 | +    for annotation_dir in cochlea_annotations:  | 
 | 93 | +        print(f"Localizing threshold with median intensities for {os.path.basename(annotation_dir)}.")  | 
 | 94 | +        annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg,  | 
 | 95 | +                                                     table_measurement, column=column, pattern=pattern)  | 
 | 96 | +        annotated_centers.extend(annotation_dic["center_strings"])  | 
 | 97 | +        annotation_dics[annotation_dir] = annotation_dic  | 
 | 98 | + | 
 | 99 | +    annotated_centers = list(set(annotated_centers))  | 
 | 100 | +    intensity_dic = {}  | 
 | 101 | +    # loop over all annotated blocks  | 
 | 102 | +    for annotated_center in annotated_centers:  | 
 | 103 | +        intensities = []  | 
 | 104 | +        annotator_success = []  | 
 | 105 | +        annotator_failure = []  | 
 | 106 | +        annotator_missing = []  | 
 | 107 | +        # loop over annotated block from single user  | 
 | 108 | +        for annotator_key in annotation_dics.keys():  | 
 | 109 | +            if annotated_center not in annotation_dics[annotator_key]["center_strings"]:  | 
 | 110 | +                annotator_missing.append(os.path.basename(annotator_key))  | 
 | 111 | +                continue  | 
 | 112 | +            else:  | 
 | 113 | +                median_intensity = annotation_dics[annotator_key][annotated_center]["median_intensity"]  | 
 | 114 | +                if median_intensity is None:  | 
 | 115 | +                    print(f"No threshold for {os.path.basename(annotator_key)} and crop {annotated_center}.")  | 
 | 116 | +                    annotator_failure.append(os.path.basename(annotator_key))  | 
 | 117 | +                else:  | 
 | 118 | +                    intensities.append(median_intensity)  | 
 | 119 | +                    annotator_success.append(os.path.basename(annotator_key))  | 
 | 120 | + | 
 | 121 | +        if len(intensities) == 0:  | 
 | 122 | +            print(f"No viable annotation for cochlea {cochlea} and crop {annotated_center}.")  | 
 | 123 | +            median_int_avg = None  | 
 | 124 | +        else:  | 
 | 125 | +            median_int_avg = float(sum(intensities) / len(intensities)),  | 
 | 126 | + | 
 | 127 | +        intensity_dic[annotated_center] = {  | 
 | 128 | +            "median_intensity": median_int_avg,  | 
 | 129 | +            "annotation_success": annotator_success,  | 
 | 130 | +            "annotation_failure": annotator_failure,  | 
 | 131 | +            "annotation_missing": annotator_missing,  | 
 | 132 | +        }  | 
 | 133 | + | 
 | 134 | +    return intensity_dic  | 
 | 135 | + | 
 | 136 | + | 
 | 137 | +def evaluate_marker_annotation(  | 
 | 138 | +    cochleae: List[str],  | 
 | 139 | +    output_dir: str,  | 
 | 140 | +    annotation_dirs: Optional[List[str]] = None,  | 
 | 141 | +    seg_name: str = "SGN_v2",  | 
 | 142 | +    marker_name: str = "Calb1",  | 
 | 143 | +    threshold_save_dir: Optional[str] = None,  | 
 | 144 | +    force: bool = False,  | 
 | 145 | +) -> None:  | 
 | 146 | +    """Evaluate marker annotations of a single or multiple annotators.  | 
 | 147 | +    Segmentation instances are assigned a positive (1) or negative label (2)  | 
 | 148 | +    in form of the "marker_label" component of the output segmentation table.  | 
 | 149 | +    The assignment is based on the median intensity supplied by a measurement table.  | 
 | 150 | +    Instances not considered for the assignment are labeled as 0.  | 
 | 151 | +
  | 
 | 152 | +    Args:  | 
 | 153 | +        cochleae: List of cochlea  | 
 | 154 | +        output_dir: Output directory for segmentation table with "marker_label" in format <cochlea>_<marker>_<seg>.tsv  | 
 | 155 | +        annotation_dirs: List of directories containing marker annotations by annotator(s).  | 
 | 156 | +        seg_name: Identifier for segmentation.  | 
 | 157 | +        marker_name: Identifier for marker stain.  | 
 | 158 | +        threshold_save_dir: Optional directory for saving the thresholds.  | 
 | 159 | +        force: Whether to overwrite already existing results.  | 
 | 160 | +    """  | 
 | 161 | +    input_key = "s0"  | 
 | 162 | + | 
 | 163 | +    if annotation_dirs is None:  | 
 | 164 | +        marker_dir = MARKER_DIR_SUBTYPE  | 
 | 165 | +        annotation_dirs = [entry.path for entry in os.scandir(marker_dir)  | 
 | 166 | +                           if os.path.isdir(entry) and "Result" in entry.name]  | 
 | 167 | + | 
 | 168 | +    for cochlea in cochleae:  | 
 | 169 | +        data_name = COCHLEAE[cochlea]["seg_data"]  | 
 | 170 | +        if "output_seg" in list(COCHLEAE[cochlea].keys()):  | 
 | 171 | +            output_seg = COCHLEAE[cochlea]["output_seg"]  | 
 | 172 | +        else:  | 
 | 173 | +            output_seg = data_name  | 
 | 174 | + | 
 | 175 | +        seg_string = "-".join(output_seg.split("_"))  | 
 | 176 | +        cochlea_str = "-".join(cochlea.split("_"))  | 
 | 177 | +        subtypes = COCHLEAE[cochlea]["subtype"]  | 
 | 178 | +        subtype_str = "_".join(subtypes)  | 
 | 179 | +        out_path = os.path.join(output_dir, f"{cochlea_str}_{subtype_str}_{seg_string}.tsv")  | 
 | 180 | +        if os.path.exists(out_path) and not force:  | 
 | 181 | +            continue  | 
 | 182 | + | 
 | 183 | +        # Get the segmentation data and table.  | 
 | 184 | +        input_path = f"{cochlea}/images/ome-zarr/{data_name}.ome.zarr"  | 
 | 185 | +        input_path, fs = get_s3_path(input_path)  | 
 | 186 | +        data_seg = read_image_data(input_path, input_key)  | 
 | 187 | + | 
 | 188 | +        table_seg_path = f"{cochlea}/tables/{output_seg}/default.tsv"  | 
 | 189 | +        table_path_s3, fs = get_s3_path(table_seg_path)  | 
 | 190 | +        with fs.open(table_path_s3, "r") as f:  | 
 | 191 | +            table_seg = pd.read_csv(f, sep="\t")  | 
 | 192 | + | 
 | 193 | +        table_measurement_path = f"{cochlea}/tables/{data_name}/subtype_ratio.tsv"  | 
 | 194 | + | 
 | 195 | +        # iterate through subtypes  | 
 | 196 | +        for subtype in subtypes:  | 
 | 197 | +            column = f"{subtype}_ratio_PV"  | 
 | 198 | +            table_path_s3, fs = get_s3_path(table_measurement_path)  | 
 | 199 | +            with fs.open(table_path_s3, "r") as f:  | 
 | 200 | +                table_measurement = pd.read_csv(f, sep="\t")  | 
 | 201 | + | 
 | 202 | +            cochlea_annotations = [a for a in annotation_dirs  | 
 | 203 | +                                   if len(find_annotations(a, cochlea, subtype)["center_strings"]) != 0]  | 
 | 204 | +            print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.")  | 
 | 205 | + | 
 | 206 | +            # Find the threholds from the annotated blocks and save it if specified.  | 
 | 207 | +            intensity_dic = find_thresholds(cochlea_annotations, cochlea, data_seg,  | 
 | 208 | +                                            table_measurement, column=column, pattern=subtype)  | 
 | 209 | +            if threshold_save_dir is not None:  | 
 | 210 | +                os.makedirs(threshold_save_dir, exist_ok=True)  | 
 | 211 | +                threshold_out_path = os.path.join(threshold_save_dir, f"{cochlea_str}_{subtype}_{seg_string}.json")  | 
 | 212 | +                with open(threshold_out_path, "w") as f:  | 
 | 213 | +                    json.dump(intensity_dic, f, sort_keys=True, indent=4)  | 
 | 214 | + | 
 | 215 | +            # Apply the threshold to all SGNs.  | 
 | 216 | +            table_seg = apply_nearest_threshold(  | 
 | 217 | +                intensity_dic, table_seg, table_measurement, column=column, suffix=subtype,  | 
 | 218 | +            )  | 
 | 219 | + | 
 | 220 | +        # Save the table with positives / negatives for all SGNs.  | 
 | 221 | +        os.makedirs(output_dir, exist_ok=True)  | 
 | 222 | +        table_seg.to_csv(out_path, sep="\t", index=False)  | 
 | 223 | + | 
 | 224 | + | 
 | 225 | +def main():  | 
 | 226 | +    parser = argparse.ArgumentParser(  | 
 | 227 | +        description="Assign each segmentation instance a marker based on annotation thresholds."  | 
 | 228 | +    )  | 
 | 229 | + | 
 | 230 | +    parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=COCHLEAE, help="Cochlea(e) to process.")  | 
 | 231 | +    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.")  | 
 | 232 | +    parser.add_argument("-a", "--annotation_dirs", type=str, nargs="+", default=None,  | 
 | 233 | +                        help="Directories containing marker annotations.")  | 
 | 234 | +    parser.add_argument("--threshold_save_dir", "-t")  | 
 | 235 | +    parser.add_argument("-f", "--force", action="store_true")  | 
 | 236 | + | 
 | 237 | +    args = parser.parse_args()  | 
 | 238 | +    evaluate_marker_annotation(  | 
 | 239 | +        args.cochlea, args.output, args.annotation_dirs, threshold_save_dir=args.threshold_save_dir, force=args.force,  | 
 | 240 | +    )  | 
 | 241 | + | 
 | 242 | + | 
 | 243 | +if __name__ == "__main__":  | 
 | 244 | +    main()  | 
0 commit comments