| 
 | 1 | +import argparse  | 
 | 2 | +import os  | 
 | 3 | + | 
 | 4 | +import numpy as np  | 
 | 5 | +import pandas as pd  | 
 | 6 | +import zarr  | 
 | 7 | +from elf.io import open_file  | 
 | 8 | +import scipy.ndimage as ndimage  | 
 | 9 | + | 
 | 10 | +from flamingo_tools.s3_utils import get_s3_path  | 
 | 11 | +from flamingo_tools.segmentation.marker_extension import distance_based_marker_extension  | 
 | 12 | +from flamingo_tools.file_utils import read_image_data  | 
 | 13 | + | 
 | 14 | + | 
 | 15 | +def main():  | 
 | 16 | +    parser = argparse.ArgumentParser(  | 
 | 17 | +        description="Script for the extension of an SGN detection. "  | 
 | 18 | +        "Either locally or on an S3 bucket.")  | 
 | 19 | + | 
 | 20 | +    parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.")  | 
 | 21 | +    parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.")  | 
 | 22 | +    parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.")  | 
 | 23 | +    parser.add_argument("--input", default=None, help="Input tif.")  | 
 | 24 | + | 
 | 25 | +    parser.add_argument("--component_labels", type=int, nargs="+", default=[1],  | 
 | 26 | +                        help="Component labels of SGN_detect.")  | 
 | 27 | +    parser.add_argument("-d", "--extension_distance", type=float, default=8, help="Extension distance.")  | 
 | 28 | +    parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],  | 
 | 29 | +                        help="Resolution of input in micrometer.")  | 
 | 30 | + | 
 | 31 | +    args = parser.parse_args()  | 
 | 32 | + | 
 | 33 | +    block_shape = (128, 128, 128)  | 
 | 34 | +    chunks = (128, 128, 128)  | 
 | 35 | + | 
 | 36 | +    if len(args.resolution) == 1:  | 
 | 37 | +        resolution = tuple(args.resolution, args.resolution, args.resolution)  | 
 | 38 | +    else:  | 
 | 39 | +        resolution = tuple(args.resolution)  | 
 | 40 | + | 
 | 41 | +    if args.input is not None:  | 
 | 42 | +        data = read_image_data(args.input, None)  | 
 | 43 | +        shape = data.shape  | 
 | 44 | +        # Compute centers of mass for each label (excluding background = 0)  | 
 | 45 | +        markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))  | 
 | 46 | +        markers = np.array(markers)  | 
 | 47 | + | 
 | 48 | +    else:  | 
 | 49 | + | 
 | 50 | +        s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv")  | 
 | 51 | +        tsv_path, fs = get_s3_path(s3_path)  | 
 | 52 | +        with fs.open(tsv_path, 'r') as f:  | 
 | 53 | +            table = pd.read_csv(f, sep="\t")  | 
 | 54 | + | 
 | 55 | +        table = table.loc[table["component_labels"].isin(args.component_labels)]  | 
 | 56 | +        markers = list(zip(table["anchor_x"] / resolution[0],  | 
 | 57 | +                           table["anchor_y"] / resolution[1],  | 
 | 58 | +                           table["anchor_z"] / resolution[2]))  | 
 | 59 | +        markers = np.array(markers)  | 
 | 60 | + | 
 | 61 | +        s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr")  | 
 | 62 | +        input_key = "s0"  | 
 | 63 | +        s3_store, fs = get_s3_path(s3_path)  | 
 | 64 | +        with zarr.open(s3_store, mode="r") as f:  | 
 | 65 | +            data = f[input_key][:].astype("float32")  | 
 | 66 | + | 
 | 67 | +        shape = data.shape  | 
 | 68 | + | 
 | 69 | +    output_key = "extended_segmentation"  | 
 | 70 | +    output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr")  | 
 | 71 | + | 
 | 72 | +    output = open_file(output_path, mode="a")  | 
 | 73 | +    output_dataset = output.create_dataset(  | 
 | 74 | +        output_key, shape=shape, dtype=data.dtype,  | 
 | 75 | +        chunks=chunks, compression="gzip"  | 
 | 76 | +    )  | 
 | 77 | + | 
 | 78 | +    distance_based_marker_extension(  | 
 | 79 | +        markers=markers,  | 
 | 80 | +        output=output_dataset,  | 
 | 81 | +        extension_distance=args.extension_distance,  | 
 | 82 | +        sampling=resolution,  | 
 | 83 | +        block_shape=block_shape,  | 
 | 84 | +        n_threads=16,  | 
 | 85 | +    )  | 
 | 86 | + | 
 | 87 | + | 
 | 88 | +if __name__ == "__main__":  | 
 | 89 | +    main()  | 
0 commit comments