diff --git a/environment.yaml b/environment.yaml
index 7ee7eba..0962641 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -6,6 +6,7 @@ channels:
 dependencies:
   - cluster_tools
   - scikit-image
+  - pooch
   - pybdv
   - pytorch
   - s3fs
diff --git a/flamingo_tools/segmentation/sgn_detection.py b/flamingo_tools/segmentation/sgn_detection.py
new file mode 100644
index 0000000..c15a30e
--- /dev/null
+++ b/flamingo_tools/segmentation/sgn_detection.py
@@ -0,0 +1,180 @@
+import multiprocessing
+import os
+import threading
+from concurrent import futures
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import zarr
+from numpy.typing import ArrayLike
+from scipy.ndimage import distance_transform_edt
+from skimage.segmentation import watershed
+from threadpoolctl import threadpool_limits
+from tqdm import tqdm
+
+from elf.io import open_file
+from elf.parallel.common import get_blocking
+from elf.parallel.local_maxima import find_local_maxima
+from flamingo_tools.segmentation.unet_prediction import prediction_impl
+
+
+def distance_based_marker_extension(
+    markers: np.ndarray,
+    output: ArrayLike,
+    extension_distance: float,
+    sampling: Union[float, Tuple[float, ...]],
+    block_shape: Tuple[int, ...],
+    n_threads: Optional[int] = None,
+    verbose: bool = False,
+    roi: Optional[Tuple[slice, ...]] = None,
+):
+    """Extend SGN detections to emulate the shape of SGNs for better visualization.
+
+    Args:
+        markers: Array of marker coordinates (voxel coordinates in zyx order) for seeding the watershed.
+        output: Output dataset for the extended segmentation.
+        extension_distance: Distance in micrometer by which the markers are extended.
+        sampling: Resolution in micrometer.
+        block_shape: The block shape for parallel processing.
+        n_threads: The number of threads for parallel processing. By default all available threads are used.
+        verbose: Whether to print the progress.
+        roi: Optional region of interest to restrict the extension to.
+    """
+    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
+    blocking = get_blocking(output, block_shape, roi, n_threads)
+
+    lock = threading.Lock()
+
+    # Determine the correct halo in pixels based on the sampling and the extension distance.
+    halo = [round(extension_distance / s) + 2 for s in sampling]
+
+    @threadpool_limits.wrap(limits=1)  # Restrict the numpy threadpool to 1 to avoid oversubscription.
+    def extend_block(block_id):
+        block = blocking.getBlockWithHalo(block_id, halo)
+        outer_block = block.outerBlock
+        inner_block = block.innerBlock
+
+        # Get the indices and coordinates of the markers in the INNER block.
+        mask = (
+            (inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] <= inner_block.end[0]) &
+            (inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] <= inner_block.end[1]) &
+            (inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] <= inner_block.end[2])
+        )
+        markers_in_block_ids = np.where(mask)[0]
+        markers_in_block_coords = markers[markers_in_block_ids]
+
+        # Proceed only if detections fall within the inner block.
+        if len(markers_in_block_coords) > 0:
+            markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]
+            markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]
+
+            markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)
+            z, y, x = markers_in_block_coords.T
+
+            # Shift the indices by one so that zero is reserved for the background id.
+            markers_in_block_ids += 1
+
+            # Create the seed volume.
+            outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
+            seeds = np.zeros(outer_block_shape, dtype="uint32")
+            seeds[z, y, x] = markers_in_block_ids
+
+            # Compute the distance map.
+            distance = distance_transform_edt(seeds == 0, sampling=sampling)
+
+            # And extend the seeds via a masked watershed on the inverted distance map.
+            mask = distance < extension_distance
+            segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)
+
+            # Write the segmentation. Note: we need to lock here because we write outside of our inner block.
+            bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
+            with lock:
+                this_output = output[bb]
+                this_output[mask] = segmentation[mask]
+                output[bb] = this_output
+
+    n_blocks = blocking.numberOfBlocks
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        list(tqdm(
+            tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
+        ))
+
+
+def sgn_detection(
+    input_path: str,
+    input_key: str,
+    output_folder: str,
+    model_path: str,
+    extension_distance: float,
+    sampling: Union[float, Tuple[float, ...]],
+    block_shape: Optional[Tuple[int, int, int]] = None,
+    halo: Optional[Tuple[int, int, int]] = None,
+    n_threads: Optional[int] = None,
+):
+    """Run prediction for SGN detection and extend the detected markers into a segmentation.
+
+    Args:
+        input_path: Input path to image channel for SGN detection.
+        input_key: Input key for resolution of image channel and mask channel.
+        output_folder: Output folder for SGN segmentation.
+        model_path: Path to model for SGN detection.
+        extension_distance: Distance in micrometer by which the detected markers are extended.
+        sampling: Resolution of the input in micrometer.
+        block_shape: The block-shape for running the prediction.
+        halo: The halo (= block overlap) to use for prediction.
+        n_threads: The number of threads for the marker extension.
+    """
+    if block_shape is None:
+        block_shape = (12, 128, 128)
+    if halo is None:
+        halo = (10, 64, 64)
+
+    # Skip the prediction if it already exists; it is saved in output_folder/predictions.zarr.
+    skip_prediction = False
+    output_path = os.path.join(output_folder, "predictions.zarr")
+    prediction_key = "prediction"
+    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
+        skip_prediction = True
+
+    if not skip_prediction:
+        prediction_impl(
+            input_path, input_key, output_folder, model_path,
+            scale=None, block_shape=block_shape, halo=halo,
+            apply_postprocessing=False, output_channels=1,
+        )
+
+    input_ = zarr.open(output_path, "r")[prediction_key]
+
+    detection_path = os.path.join(output_folder, "SGN_detection.tsv")
+    if not os.path.exists(detection_path):
+        detections_maxima = find_local_maxima(
+            input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=16,
+        )
+
+        # Save the result in MoBIE compatible format.
+        detections = np.concatenate(
+            [np.arange(1, len(detections_maxima) + 1)[:, None], detections_maxima[:, ::-1]], axis=1
+        )
+        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
+        detections.to_csv(detection_path, index=False, sep="\t")
+    else:
+        # Reload existing detections and convert them back to zyx order.
+        detections = pd.read_csv(detection_path, sep="\t")
+        detections_maxima = detections[["z", "y", "x"]].values
+
+    # Extend the detections into a segmentation.
+    shape = input_.shape
+    chunks = (128, 128, 128)
+    segmentation_path = os.path.join(output_folder, "segmentation.zarr")
+    output = open_file(segmentation_path, mode="a")
+    segmentation_key = "segmentation"
+    output_dataset = output.create_dataset(
+        segmentation_key, shape=shape, dtype=np.uint64,
+        chunks=chunks, compression="gzip"
+    )
+
+    distance_based_marker_extension(
+        markers=detections_maxima,
+        output=output_dataset,
+        extension_distance=extension_distance,
+        sampling=sampling,
+        block_shape=(128, 128, 128),
+        n_threads=n_threads,
+    )
diff --git a/reproducibility/templates_processing/detect_sgn_template.sbatch b/reproducibility/templates_processing/detect_sgn_template.sbatch
new file mode 100644
index 0000000..f224826
--- /dev/null
+++ b/reproducibility/templates_processing/detect_sgn_template.sbatch
@@ -0,0 +1,60 @@
+#!/bin/bash
+#SBATCH --job-name=sgn-detect
+#SBATCH -t 03:00:00        # estimated time, adapt to your needs
+#SBATCH --mail-type=FAIL   # send mail when the job fails
+
+#SBATCH -p grete:shared    # the partition
+#SBATCH -G A100:1          # request 1 A100 GPU
+#SBATCH -A nim00007
+#SBATCH -c 4
+#SBATCH --mem 32G
+
+source ~/.bashrc
+# micromamba activate micro-sam_gpu
+micromamba activate sam
+
+# Print out some info.
+echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}"
+echo "Home directory: ${HOME}"
+echo "Working directory: $PWD"
+echo "Current node: ${SLURM_NODELIST}"
+
+# Run the script
+# python myprogram.py $SLURM_ARRAY_TASK_ID
+
+# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools
+SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools
+cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit
+
+export SCRIPT_DIR=$SCRIPT_REPO/scripts
+
+# Name of the cochlea, as it appears in MoBIE and on the NHR.
+COCHLEA=$1
+# Model version of the SGN detection, e.g. v5b.
+MODEL_VERSION=$2
+
+# Data on the NHR.
+MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/
+export INPUT_PATH="$MOBIE_DIR"/"$COCHLEA"/images/ome-zarr/PV.ome.zarr
+
+# Data on MoBIE (use the --s3 flag of the script instead):
+# export INPUT_PATH="$COCHLEA"/images/ome-zarr/PV.ome.zarr
+
+export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/SGN_detect-"$MODEL_VERSION"
+
+if ! [[ -d "$OUTPUT_FOLDER" ]]; then
+    mkdir -p "$OUTPUT_FOLDER"
+fi
+
+export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-"$MODEL_VERSION".pt
+INPUT_KEY="s0"
+
+echo "OUTPUT_FOLDER $OUTPUT_FOLDER"
+echo "MODEL $MODEL"
+
+python "$SCRIPT_DIR"/sgn_detection/sgn_detection.py \
+    --input "$INPUT_PATH" \
+    --input_key $INPUT_KEY \
+    --output_folder "$OUTPUT_FOLDER" \
+    --model "$MODEL"
diff --git a/scripts/sgn_detection/sgn_detection.py b/scripts/sgn_detection/sgn_detection.py
new file mode 100644
index 0000000..2d31da7
--- /dev/null
+++ b/scripts/sgn_detection/sgn_detection.py
@@ -0,0 +1,54 @@
+import argparse
+
+import flamingo_tools.s3_utils as s3_utils
+from flamingo_tools.segmentation.sgn_detection import sgn_detection
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run SGN detection and marker extension.")
+    parser.add_argument("-i", "--input", required=True, help="Path to image data to be segmented.")
+    parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.")
+    parser.add_argument("-m", "--model", required=True,
+                        help="Path to SGN detection model.")
+    parser.add_argument("-k", "--input_key", default=None,
+                        help="The key / internal path to the image data.")
+
+    parser.add_argument("-d", "--extension_distance", type=float, default=12,
+                        help="Extension distance in micrometer.")
+    parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
+                        help="Resolution of the input in micrometer (zyx order, or a single isotropic value).")
+
+    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
+    parser.add_argument("--s3_credentials", type=str, default=None,
+                        help="Input file containing S3 credentials. "
+                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
+    parser.add_argument("--s3_bucket_name", type=str, default=None,
+                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
+    parser.add_argument("--s3_service_endpoint", type=str, default=None,
+                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
+
+    args = parser.parse_args()
+
+    block_shape = (12, 128, 128)
+    halo = (10, 64, 64)
+
+    if len(args.resolution) == 1:
+        resolution = (args.resolution[0],) * 3
+    else:
+        resolution = tuple(args.resolution)
+
+    if args.s3:
+        input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
+                                              service_endpoint=args.s3_service_endpoint,
+                                              credential_file=args.s3_credentials)
+    else:
+        input_path = args.input
+
+    sgn_detection(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder,
+                  model_path=args.model, block_shape=block_shape, halo=halo,
+                  extension_distance=args.extension_distance, sampling=resolution)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/sgn_detection/sgn_marker_extension.py b/scripts/sgn_detection/sgn_marker_extension.py
new file mode 100644
index 0000000..3a702c4
--- /dev/null
+++ b/scripts/sgn_detection/sgn_marker_extension.py
@@ -0,0 +1,89 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import scipy.ndimage as ndimage
+import zarr
+
+from elf.io import open_file
+from flamingo_tools.s3_utils import get_s3_path
+from flamingo_tools.segmentation.sgn_detection import distance_based_marker_extension
+from flamingo_tools.file_utils import read_image_data
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Script for the extension of an SGN detection, "
" + "Either locally or on an S3 bucket.") + + parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.") + parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.") + parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.") + parser.add_argument("--input", default=None, help="Input tif.") + + parser.add_argument("--component_labels", type=int, nargs="+", default=[1], + help="Component labels of SGN_detect.") + parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.") + parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779], + help="Resolution of input in micrometer.") + + args = parser.parse_args() + + block_shape = (128, 128, 128) + chunks = (128, 128, 128) + + if len(args.resolution) == 1: + resolution = tuple(args.resolution, args.resolution, args.resolution) + else: + resolution = tuple(args.resolution) + + if args.input is not None: + data = read_image_data(args.input, None) + shape = data.shape + # Compute centers of mass for each label (excluding background = 0) + markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0])) + markers = np.array(markers) + + else: + + s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv") + tsv_path, fs = get_s3_path(s3_path) + with fs.open(tsv_path, 'r') as f: + table = pd.read_csv(f, sep="\t") + + table = table.loc[table["component_labels"].isin(args.component_labels)] + markers = list(zip(table["anchor_x"] / resolution[0], + table["anchor_y"] / resolution[1], + table["anchor_z"] / resolution[2])) + markers = np.array(markers) + + s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr") + input_key = "s0" + s3_store, fs = get_s3_path(s3_path) + with zarr.open(s3_store, mode="r") as f: + data = f[input_key][:].astype("float32") + + shape = data.shape + + output_key = "extended_segmentation" + output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr") + + output = open_file(output_path, mode="a") + output_dataset = output.create_dataset( + output_key, shape=shape, dtype=np.dtype("uint32"), + chunks=chunks, compression="gzip" + ) + + distance_based_marker_extension( + markers=markers, + output=output_dataset, + extension_distance=args.extension_distance, + sampling=resolution, + block_shape=block_shape, + n_threads=16, + ) + + +if __name__ == "__main__": + main()