1 change: 1 addition & 0 deletions environment.yaml
@@ -6,6 +6,7 @@ channels:
dependencies:
- cluster_tools
- scikit-image
- pooch
- pybdv
- pytorch
- s3fs
180 changes: 180 additions & 0 deletions flamingo_tools/segmentation/sgn_detection.py
@@ -0,0 +1,180 @@
import multiprocessing
import os
import threading
from concurrent import futures
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import zarr
from numpy.typing import ArrayLike
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import watershed
from threadpoolctl import threadpool_limits
from tqdm import tqdm

from elf.io import open_file
from elf.parallel.common import get_blocking
from elf.parallel.local_maxima import find_local_maxima

from flamingo_tools.segmentation.unet_prediction import prediction_impl


def distance_based_marker_extension(
markers: np.ndarray,
output: ArrayLike,
extension_distance: float,
sampling: Union[float, Tuple[float, ...]],
block_shape: Tuple[int, ...],
n_threads: Optional[int] = None,
verbose: bool = False,
roi: Optional[Tuple[slice, ...]] = None,
):
"""
Extend SGN detection to emulate shape of SGNs for better visualization.

Args:
markers: Array of coordinates for seeding watershed.
output: Output for watershed.
extension_distance: Distance in micrometer for extension.
sampling: Resolution in micrometer.
block_shape:
n_threads:
verbose:
roi:
"""
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
blocking = get_blocking(output, block_shape, roi, n_threads)

lock = threading.Lock()

# determine the correct halo in pixels based on the sampling and the extension distance.
halo = [round(extension_distance / s) + 2 for s in sampling]

@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
def extend_block(block_id):
block = blocking.getBlockWithHalo(block_id, halo)
outer_block = block.outerBlock
inner_block = block.innerBlock

        # Get the indices and coordinates of the markers in the INNER block.
        # Use half-open intervals so that markers on a block boundary are assigned to exactly one block.
        mask = (
            (inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] < inner_block.end[0]) &
            (inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] < inner_block.end[1]) &
            (inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] < inner_block.end[2])
        )
markers_in_block_ids = np.where(mask)[0]
markers_in_block_coords = markers[markers_in_block_ids]

# proceed if detections fall within inner block
if len(markers_in_block_coords) > 0:
markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]
markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]

markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)
z, y, x = markers_in_block_coords.T

# Shift index by one so that zero is reserved for background id
markers_in_block_ids += 1

# Create the seed volume.
outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
seeds = np.zeros(outer_block_shape, dtype="uint32")
seeds[z, y, x] = markers_in_block_ids

# Compute the distance map.
distance = distance_transform_edt(seeds == 0, sampling=sampling)

# And extend the seeds
mask = distance < extension_distance
segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)

# Write the segmentation. Note: we need to lock here because we write outside of our inner block
bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
with lock:
this_output = output[bb]
this_output[mask] = segmentation[mask]
output[bb] = this_output

n_blocks = blocking.numberOfBlocks
with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(
tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
))


def sgn_detection(
input_path: str,
input_key: str,
output_folder: str,
model_path: str,
extension_distance: float,
sampling: Union[float, Tuple[float, ...]],
block_shape: Optional[Tuple[int, int, int]] = None,
halo: Optional[Tuple[int, int, int]] = None,
n_threads: Optional[int] = None,
):
"""Run prediction for SGN detection.

Args:
input_path: Input path to image channel for SGN detection.
input_key: Input key for resolution of image channel and mask channel.
output_folder: Output folder for SGN segmentation.
model_path: Path to model for SGN detection.
block_shape: The block-shape for running the prediction.
halo: The halo (= block overlap) to use for prediction.
spot_radius: Radius in pixel to convert spot detection of SGNs into a volume.
"""
if block_shape is None:
block_shape = (12, 128, 128)
if halo is None:
halo = (10, 64, 64)

# Skip existing prediction, which is saved in output_folder/predictions.zarr
skip_prediction = False
output_path = os.path.join(output_folder, "predictions.zarr")
prediction_key = "prediction"
if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
skip_prediction = True

if not skip_prediction:
prediction_impl(
input_path, input_key, output_folder, model_path,
scale=None, block_shape=block_shape, halo=halo,
apply_postprocessing=False, output_channels=1,
)

    # Open the prediction; it is also needed below to determine the output shape for the marker extension.
    input_ = zarr.open(output_path, "r")[prediction_key]

    detection_path = os.path.join(output_folder, "SGN_detection.tsv")
    if os.path.exists(detection_path):
        # Load previously computed detections; the table stores the coordinates in xyz order.
        detections = pd.read_csv(detection_path, sep="\t")
        detections_maxima = detections[["z", "y", "x"]].values
    else:
        detections_maxima = find_local_maxima(
            input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=16,
        )

        # Save the result in MoBIE-compatible format.
        detections = np.concatenate(
            [np.arange(1, len(detections_maxima) + 1)[:, None], detections_maxima[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")

    # Extend the detections into a volumetric segmentation.
    shape = input_.shape
    chunks = (128, 128, 128)
    segmentation_path = os.path.join(output_folder, "segmentation.zarr")
    output = open_file(segmentation_path, mode="a")
    segmentation_key = "segmentation"
    # require_dataset makes re-running the script safe if the dataset already exists.
    output_dataset = output.require_dataset(
        segmentation_key, shape=shape, dtype=np.uint64,
        chunks=chunks, compression="gzip"
    )

distance_based_marker_extension(
markers=detections_maxima,
output=output_dataset,
extension_distance=extension_distance,
sampling=sampling,
block_shape=(128, 128, 128),
n_threads=n_threads,
)

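A minimal usage sketch of distance_based_marker_extension on a small in-memory volume, assuming a plain numpy array is acceptable as output; the marker coordinates, distances, and shapes below are made up for illustration.

import numpy as np

from flamingo_tools.segmentation.sgn_detection import distance_based_marker_extension

# Two hypothetical detections in zyx voxel coordinates.
markers = np.array([[10.0, 32.0, 32.0], [10.0, 96.0, 96.0]])
# Writable output volume for the extended segmentation.
output = np.zeros((20, 128, 128), dtype="uint32")

distance_based_marker_extension(
    markers=markers, output=output, extension_distance=8.0,
    sampling=(3.0, 1.887779, 1.887779), block_shape=(20, 64, 64),
    n_threads=2, verbose=True,
)
# Each marker is now extended into a ball-like label of roughly 8 micrometer radius.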
60 changes: 60 additions & 0 deletions reproducibility/templates_processing/detect_sgn_template.sbatch
@@ -0,0 +1,60 @@
#!/bin/bash
#SBATCH --job-name=sgn-detect
#SBATCH -t 03:00:00                  # estimated time, adapt to your needs
#SBATCH --mail-type=FAIL             # send mail if the job fails

#SBATCH -p grete:shared # the partition
#SBATCH -G A100:1 # For requesting 1 A100 GPU.
#SBATCH -A nim00007
#SBATCH -c 4
#SBATCH --mem 32G

source ~/.bashrc
# micromamba activate micro-sam_gpu
micromamba activate sam

# Print out some info.
echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}"
echo "Home directory: ${HOME}"
echo "Working directory: $PWD"
echo "Current node: ${SLURM_NODELIST}"

# Run the script
#python myprogram.py $SLURM_ARRAY_TASK_ID

# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools
SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools
cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit

export SCRIPT_DIR=$SCRIPT_REPO/scripts

# name of cochlea, as it appears in MoBIE and the NHR
COCHLEA=$1
# model of SGN detection, e.g. v5b
MODEL_VERSION=$2

# data on NHR
MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/
export INPUT_PATH="$MOBIE_DIR"/"$COCHLEA"/images/ome-zarr/PV.ome.zarr

# data on MoBIE
# export INPUT_PATH="$COCHLEA"/images/ome-zarr/PV.ome.zarr
# use --s3 flag for script

export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/SGN_detect-"$MODEL_VERSION"

if ! [[ -d "$OUTPUT_FOLDER" ]] ; then
	mkdir -p "$OUTPUT_FOLDER"
fi

export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-"$MODEL_VERSION".pt
INPUT_KEY="s0"

echo "OUTPUT_FOLDER $OUTPUT_FOLDER"
echo "MODEL $MODEL"

python "$SCRIPT_REPO"/scripts/sgn_detection/sgn_detection.py \
	--input "$INPUT_PATH" \
	--input_key "$INPUT_KEY" \
	--output_folder "$OUTPUT_FOLDER" \
	--model "$MODEL"
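A hedged usage note: the template expects the cochlea name and the model version as positional arguments, so a submission would look like sbatch detect_sgn_template.sbatch MyCochlea v5b (the cochlea name here is hypothetical); SCRIPT_REPO and the partition/account settings need to be adapted to your cluster before submitting.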
54 changes: 54 additions & 0 deletions scripts/sgn_detection/sgn_detection.py
@@ -0,0 +1,54 @@
import argparse

import flamingo_tools.s3_utils as s3_utils
from flamingo_tools.segmentation.sgn_detection import sgn_detection


def main():

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", required=True, help="Path to image data to be segmented.")
parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.")
parser.add_argument("-m", "--model", required=True,
help="Path to SGN detection model.")
parser.add_argument("-k", "--input_key", default=None,
help="The key / internal path to image data.")

parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.")
parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
help="Resolution of input in micrometer.")

parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
parser.add_argument("--s3_credentials", type=str, default=None,
help="Input file containing S3 credentials. "
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
parser.add_argument("--s3_bucket_name", type=str, default=None,
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
parser.add_argument("--s3_service_endpoint", type=str, default=None,
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

args = parser.parse_args()

block_shape = (12, 128, 128)
halo = (10, 64, 64)

    if len(args.resolution) == 1:
        resolution = (args.resolution[0],) * 3
    else:
        resolution = tuple(args.resolution)

if args.s3:
input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
service_endpoint=args.s3_service_endpoint,
credential_file=args.s3_credentials)

else:
input_path = args.input

sgn_detection(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder,
model_path=args.model, block_shape=block_shape, halo=halo,
extension_distance=args.extension_distance, sampling=resolution)


if __name__ == "__main__":
main()
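A sketch of the equivalent programmatic call for local data (no S3); all paths below are hypothetical and the remaining values mirror the CLI defaults above.

from flamingo_tools.segmentation.sgn_detection import sgn_detection

sgn_detection(
    input_path="/data/MyCochlea/PV.ome.zarr",      # hypothetical path
    input_key="s0",
    output_folder="/data/predictions/MyCochlea",   # hypothetical path
    model_path="/models/sgn-detection-v5b.pt",     # hypothetical path
    extension_distance=12.0,
    sampling=(3.0, 1.887779, 1.887779),
    block_shape=(12, 128, 128),
    halo=(10, 64, 64),
)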
89 changes: 89 additions & 0 deletions scripts/sgn_detection/sgn_marker_extension.py
@@ -0,0 +1,89 @@
import argparse
import os

import numpy as np
import pandas as pd
import scipy.ndimage as ndimage
import zarr
from elf.io import open_file

from flamingo_tools.file_utils import read_image_data
from flamingo_tools.s3_utils import get_s3_path
from flamingo_tools.segmentation.sgn_detection import distance_based_marker_extension


def main():
parser = argparse.ArgumentParser(
description="Script for the extension of an SGN detection. "
"Either locally or on an S3 bucket.")

parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.")
parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.")
parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.")
parser.add_argument("--input", default=None, help="Input tif.")

parser.add_argument("--component_labels", type=int, nargs="+", default=[1],
help="Component labels of SGN_detect.")
parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.")
parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
help="Resolution of input in micrometer.")

args = parser.parse_args()

block_shape = (128, 128, 128)
chunks = (128, 128, 128)

    if len(args.resolution) == 1:
        resolution = (args.resolution[0],) * 3
    else:
        resolution = tuple(args.resolution)

if args.input is not None:
data = read_image_data(args.input, None)
shape = data.shape
# Compute centers of mass for each label (excluding background = 0)
markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))
markers = np.array(markers)

else:

s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv")
tsv_path, fs = get_s3_path(s3_path)
with fs.open(tsv_path, 'r') as f:
table = pd.read_csv(f, sep="\t")

        table = table.loc[table["component_labels"].isin(args.component_labels)]
        # The MoBIE table stores anchors in micrometer and xyz order; convert to zyx voxel coordinates.
        markers = list(zip(table["anchor_z"] / resolution[0],
                           table["anchor_y"] / resolution[1],
                           table["anchor_x"] / resolution[2]))
        markers = np.array(markers)

        s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr")
        input_key = "s0"
        s3_store, fs = get_s3_path(s3_path)
        # Only the shape of the segmentation is needed here, so avoid loading the full volume.
        f = zarr.open(s3_store, mode="r")
        shape = f[input_key].shape

output_key = "extended_segmentation"
output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr")

output = open_file(output_path, mode="a")
output_dataset = output.create_dataset(
output_key, shape=shape, dtype=np.dtype("uint32"),
chunks=chunks, compression="gzip"
)

distance_based_marker_extension(
markers=markers,
output=output_dataset,
extension_distance=args.extension_distance,
sampling=resolution,
block_shape=block_shape,
n_threads=16,
)


if __name__ == "__main__":
main()
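A toy check of the center-of-mass marker computation used in the --input branch; the label volume is made up for illustration.

import numpy as np
import scipy.ndimage as ndimage

data = np.zeros((8, 8, 8), dtype="uint32")
data[2:4, 2:4, 2:4] = 1  # first toy label
data[5:7, 5:7, 5:7] = 2  # second toy label
# With a constant weight image, center_of_mass returns the geometric centroid of each label, in zyx order.
markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))
print(np.array(markers))  # [[2.5 2.5 2.5] [5.5 5.5 5.5]]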