| 
 | 1 | +import argparse  | 
 | 2 | +import os  | 
 | 3 | +from shutil import rmtree  | 
 | 4 | + | 
 | 5 | +import pybdv.metadata as bdv_metadata  | 
 | 6 | +import torch  | 
 | 7 | +import z5py  | 
 | 8 | + | 
 | 9 | +from flamingo_tools.segmentation import run_unet_prediction, filter_isolated_objects  | 
 | 10 | +from flamingo_tools.mobie import add_raw_to_mobie, add_segmentation_to_mobie  | 
 | 11 | + | 
 | 12 | +MOBIE_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie"  | 
 | 13 | + | 
 | 14 | + | 
 | 15 | +def postprocess_seg(output_folder):  | 
 | 16 | +    print("Run segmentation postprocessing ...")  | 
 | 17 | +    seg_path = os.path.join(output_folder, "segmentation.zarr")  | 
 | 18 | +    seg_key = "segmentation"  | 
 | 19 | + | 
 | 20 | +    with z5py.File(seg_path, "r") as f:  | 
 | 21 | +        segmentation = f[seg_key][:]  | 
 | 22 | + | 
 | 23 | +    seg_filtered, n_pre, n_post = filter_isolated_objects(segmentation)  | 
 | 24 | + | 
 | 25 | +    with z5py.File(seg_path, "a") as f:  | 
 | 26 | +        chunks = f[seg_key].chunks  | 
 | 27 | +        f.create_dataset(  | 
 | 28 | +            "segmentation_postprocessed", data=seg_filtered, compression="gzip",  | 
 | 29 | +            chunks=chunks, dtype=seg_filtered.dtype  | 
 | 30 | +        )  | 
 | 31 | + | 
 | 32 | + | 
 | 33 | +def export_to_mobie(xml_path, segmentation_folder, scale, mobie_dataset, chunks):  | 
 | 34 | +    # Add to mobie:  | 
 | 35 | + | 
 | 36 | +    # - raw data (if not yet present)  | 
 | 37 | +    add_raw_to_mobie(  | 
 | 38 | +        mobie_project=MOBIE_ROOT,  | 
 | 39 | +        mobie_dataset=mobie_dataset,  | 
 | 40 | +        source_name="pv-channel",  | 
 | 41 | +        xml_path=xml_path,  | 
 | 42 | +        setup_id=0,  | 
 | 43 | +    )  | 
 | 44 | + | 
 | 45 | +    # TODO enable passing extra channel names  | 
 | 46 | +    # - additional channels  | 
 | 47 | +    setup_ids = bdv_metadata.get_setup_ids(xml_path)  | 
 | 48 | +    if len(setup_ids) > 1:  | 
 | 49 | +        extra_channel_names = ["gfp_channel", "myo_channel"]  | 
 | 50 | +        for i, setup_id in enumerate(setup_ids[1:]):  | 
 | 51 | +            add_raw_to_mobie(  | 
 | 52 | +                mobie_project=MOBIE_ROOT,  | 
 | 53 | +                mobie_dataset=mobie_dataset,  | 
 | 54 | +                source_name=extra_channel_names[i],  | 
 | 55 | +                xml_path=xml_path,  | 
 | 56 | +                setup_id=setup_id  | 
 | 57 | +            )  | 
 | 58 | + | 
 | 59 | +    # - segmentation and post-processed segmentation  | 
 | 60 | +    seg_path = os.path.join(segmentation_folder, "segmentation.zarr")  | 
 | 61 | +    seg_resolution = bdv_metadata.get_resolution(xml_path, setup_id=0)  | 
 | 62 | +    if scale == 1:  | 
 | 63 | +        seg_resolution = [2 * res for res in seg_resolution]  | 
 | 64 | +    unit = bdv_metadata.get_unit(xml_path, setup_id=0)  | 
 | 65 | + | 
 | 66 | +    seg_key = "segmentation"  | 
 | 67 | +    seg_name = "nuclei_fullscale" if scale == 0 else "nuclei_downscaled"  | 
 | 68 | +    add_segmentation_to_mobie(  | 
 | 69 | +        mobie_project=MOBIE_ROOT,  | 
 | 70 | +        mobie_dataset=mobie_dataset,  | 
 | 71 | +        source_name=seg_name,  | 
 | 72 | +        segmentation_path=seg_path,  | 
 | 73 | +        segmentation_key=seg_key,  | 
 | 74 | +        resolution=seg_resolution,  | 
 | 75 | +        unit=unit,  | 
 | 76 | +        scale_factors=4*[[2, 2, 2]],  | 
 | 77 | +        chunks=chunks,  | 
 | 78 | +    )  | 
 | 79 | + | 
 | 80 | +    seg_key = "segmentation_postprocessed"  | 
 | 81 | +    seg_name += "_postprocessed"  | 
 | 82 | +    add_segmentation_to_mobie(  | 
 | 83 | +        mobie_project=MOBIE_ROOT,  | 
 | 84 | +        mobie_dataset=mobie_dataset,  | 
 | 85 | +        source_name=seg_name,  | 
 | 86 | +        segmentation_path=seg_path,  | 
 | 87 | +        segmentation_key=seg_key,  | 
 | 88 | +        resolution=seg_resolution,  | 
 | 89 | +        unit=unit,  | 
 | 90 | +        scale_factors=4*[[2, 2, 2]],  | 
 | 91 | +        chunks=chunks,  | 
 | 92 | +    )  | 
 | 93 | + | 
 | 94 | + | 
 | 95 | +def main():  | 
 | 96 | +    parser = argparse.ArgumentParser()  | 
 | 97 | +    parser.add_argument("-i", "--input", required=True)  | 
 | 98 | +    parser.add_argument("-o", "--output_folder", required=True)  | 
 | 99 | +    parser.add_argument("-s", "--scale", required=True, type=int)  | 
 | 100 | +    parser.add_argument("-m", "--mobie_dataset", required=True)  | 
 | 101 | +    parser.add_argument("--model")  | 
 | 102 | + | 
 | 103 | +    args = parser.parse_args()  | 
 | 104 | + | 
 | 105 | +    scale = args.scale  | 
 | 106 | +    if scale == 0:  | 
 | 107 | +        min_size = 1000  | 
 | 108 | +    elif scale == 1:  | 
 | 109 | +        min_size = 250  | 
 | 110 | +    else:  | 
 | 111 | +        raise ValueError  | 
 | 112 | + | 
 | 113 | +    xml_path = args.input  | 
 | 114 | +    assert os.path.splitext(xml_path)[1] == ".xml"  | 
 | 115 | +    input_path = bdv_metadata.get_data_path(xml_path, return_absolute_path=True)  | 
 | 116 | + | 
 | 117 | +    # TODO need to make sure that PV is always setup 0  | 
 | 118 | +    input_key = f"setup0/timepoint0/s{scale}"  | 
 | 119 | + | 
 | 120 | +    have_cuda = torch.cuda.is_available()  | 
 | 121 | +    chunks = z5py.File(input_path, "r")[input_key].chunks  | 
 | 122 | +    block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks)  | 
 | 123 | +    halo = (16, 64, 64) if have_cuda else (8, 32, 32)  | 
 | 124 | + | 
 | 125 | +    if args.model is not None:  | 
 | 126 | +        model = args.model  | 
 | 127 | +    else:  | 
 | 128 | +        if scale == 0:  | 
 | 129 | +            model = "../training/checkpoints/cochlea_distance_unet"  | 
 | 130 | +        else:  | 
 | 131 | +            model = "../training/checkpoints/cochlea_distance_unet-train-downsampled"  | 
 | 132 | + | 
 | 133 | +    run_unet_prediction(  | 
 | 134 | +        input_path, input_key, args.output_folder, model,  | 
 | 135 | +        scale=None, min_size=min_size,  | 
 | 136 | +        block_shape=block_shape, halo=halo,  | 
 | 137 | +    )  | 
 | 138 | + | 
 | 139 | +    postprocess_seg(args.output_folder)  | 
 | 140 | + | 
 | 141 | +    export_to_mobie(xml_path, args.output_folder, scale, args.mobie_dataset, chunks)  | 
 | 142 | + | 
 | 143 | +    # clean up: remove segmentation folders  | 
 | 144 | +    print("Cleaning up intermediate segmentation results")  | 
 | 145 | +    print("This may take a while, but everything else is done.")  | 
 | 146 | +    print("You can check the results in the MoBIE project already at:")  | 
 | 147 | +    print(f"{MOBIE_ROOT}:{args.mobie_dataset}")  | 
 | 148 | +    rmtree(args.output_folder)  | 
 | 149 | + | 
 | 150 | + | 
 | 151 | +if __name__ == "__main__":  | 
 | 152 | +    main()  | 
0 commit comments