Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions flamingo_tools/segmentation/marker_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from threadpoolctl import threadpool_limits

import multiprocessing
from concurrent import futures
from threading import Lock
from typing import Optional, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike

from scipy.ndimage import distance_transform_edt
from skimage.segmentation import watershed
from tqdm import tqdm

from elf.parallel.common import get_blocking


def distance_based_marker_extension(
    markers: np.ndarray,
    output: ArrayLike,
    extension_distance: float,
    sampling: Union[float, Tuple[float, ...]],
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False,
    roi: Optional[Tuple[slice, ...]] = None,
):
    """Grow point markers into labeled regions, processing the volume block by block.

    For every block, the markers located inside the block are placed as seeds in the
    block (enlarged by a halo), a Euclidean distance map to the seeds is computed and
    a watershed restricted to `distance < extension_distance` assigns each voxel to a
    seed. The result is written into `output` in place.

    Args:
        markers: Array of shape (n_markers, ndim) with the marker coordinates in pixel
            units, given in the axis order of `output`. The marker in row `i` is written
            with the label `i + 1`, so that 0 stays reserved for background.
        output: Array-like object that supports slice-based reading and writing.
            It is modified in place.
        extension_distance: Maximal distance, in the units of `sampling`, up to which a
            marker is extended.
        sampling: Size of a pixel along each axis. A single number is used for all axes.
        block_shape: Shape of the blocks used for parallelization.
        n_threads: Number of threads. Defaults to the number of available CPU cores.
        verbose: Whether to show a progress bar.
        roi: Optional region of interest that is forwarded to the blocking.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    markers = np.asarray(markers)
    if markers.size == 0:
        # Without markers there is nothing to extend; leave the output untouched.
        return
    ndim = markers.shape[1]

    # The annotation allows a scalar sampling; expand it to one value per axis.
    if np.isscalar(sampling):
        sampling = (float(sampling),) * ndim
    sampling = tuple(sampling)

    blocking = get_blocking(output, block_shape, roi, n_threads)
    lock = Lock()

    # Determine the halo in pixels from the sampling and the extension distance,
    # with a small safety margin to account for rounding of the marker coordinates.
    halo = [round(extension_distance / s) + 2 for s in sampling]

    @threadpool_limits.wrap(limits=1)  # restrict the numpy threadpool to 1 to avoid oversubscription
    def extend_block(block_id):
        block = blocking.getBlockWithHalo(block_id, halo)
        outer_block = block.outerBlock
        inner_block = block.innerBlock

        # Select the markers inside the inner block. The block end is exclusive
        # (it is used as a slice stop below), so the interval is half-open and a
        # marker on a block face belongs to exactly one block.
        inner_begin = np.asarray(inner_block.begin)
        inner_end = np.asarray(inner_block.end)
        in_block = np.all((markers >= inner_begin) & (markers < inner_end), axis=1)
        marker_indices = np.flatnonzero(in_block)
        if marker_indices.size == 0:
            # Blocks without markers have nothing to write.
            return

        # Express the marker coordinates relative to the outer block and snap them to the
        # pixel grid. Rounding can move a marker one pixel past the last index of the
        # outer block, so clip the coordinates to the valid index range.
        outer_begin = np.asarray(outer_block.begin)
        outer_shape = tuple(int(end - begin) for begin, end in zip(outer_block.begin, outer_block.end))
        coords = np.round(markers[marker_indices] - outer_begin).astype(int)
        coords = np.clip(coords, 0, np.asarray(outer_shape) - 1)

        # Create the seed volume. The ids are shifted by one so that zero is the background id.
        seeds = np.zeros(outer_shape, dtype="uint32")
        seeds[tuple(coords.T)] = marker_indices + 1

        # Compute the distance of every pixel to the closest seed.
        distance = distance_transform_edt(seeds == 0, sampling=sampling)

        # Extend the seeds up to the extension distance.
        mask = distance < extension_distance
        segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)

        # Write the segmentation. Note: we need to lock here because we write outside of our inner block.
        # Where the halos of neighboring blocks overlap, the block that is written last wins.
        bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
        with lock:
            this_output = output[bb]
            this_output[mask] = segmentation[mask]
            output[bb] = this_output

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(
            tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
        ))
89 changes: 89 additions & 0 deletions scripts/prediction/sgn_marker_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import argparse
import os

import numpy as np
import pandas as pd
import zarr
from elf.io import open_file
import scipy.ndimage as ndimage

from flamingo_tools.s3_utils import get_s3_path
from flamingo_tools.segmentation.marker_extension import distance_based_marker_extension
from flamingo_tools.file_utils import read_image_data


def main():
    """Extend an SGN detection into a segmentation and save it as a zarr dataset.

    The markers are either computed as the centers of mass of the labels in a local
    input image (`--input`) or read from the `default.tsv` table of the given cochlea
    and segmentation channel on S3. The extended segmentation is written to
    `<output>/<cochlea>-<seg_channel>.zarr` under the key `extended_segmentation`.
    """
    parser = argparse.ArgumentParser(
        description="Script for the extension of an SGN detection. "
        "Either locally or on an S3 bucket.")

    parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.")
    parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.")
    parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.")
    parser.add_argument("--input", default=None, help="Input tif.")

    parser.add_argument("--component_labels", type=int, nargs="+", default=[1],
                        help="Component labels of SGN_detect.")
    parser.add_argument("-d", "--extension_distance", type=float, default=8, help="Extension distance.")
    parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
                        help="Resolution of input in micrometer.")

    args = parser.parse_args()

    block_shape = (128, 128, 128)
    chunks = (128, 128, 128)

    # A single value is interpreted as an isotropic resolution.
    if len(args.resolution) == 1:
        resolution = tuple(args.resolution) * 3
    elif len(args.resolution) == 3:
        resolution = tuple(args.resolution)
    else:
        parser.error("--resolution expects either one value or three values.")

    if args.input is not None:
        data = read_image_data(args.input, None)
        shape = data.shape
        # Compute centers of mass for each label (excluding background = 0).
        # The resulting coordinates are in pixel units and follow the axis order of the image.
        markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))
        markers = np.array(markers)

    else:

        s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv")
        tsv_path, fs = get_s3_path(s3_path)
        with fs.open(tsv_path, 'r') as f:
            table = pd.read_csv(f, sep="\t")

        table = table.loc[table["component_labels"].isin(args.component_labels)]
        # Convert the anchors to pixel coordinates in the axis order of the volume, so
        # that each marker column is divided by the resolution of the same axis.
        # NOTE(review): this assumes that the volume axes are (z, y, x) and that
        # --resolution is given in the same order - confirm against the data.
        markers = table[["anchor_z", "anchor_y", "anchor_x"]].to_numpy(dtype=float) / np.array(resolution)

        s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr")
        input_key = "s0"
        s3_store, fs = get_s3_path(s3_path)
        with zarr.open(s3_store, mode="r") as f:
            # Only the shape is needed, so avoid loading the full volume into memory.
            shape = f[input_key].shape

    output_key = "extended_segmentation"
    output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr")

    output = open_file(output_path, mode="a")
    # The output is a label image, so it is stored as uint32 independent of the input data type.
    output_dataset = output.create_dataset(
        output_key, shape=shape, dtype="uint32",
        chunks=chunks, compression="gzip"
    )

    distance_based_marker_extension(
        markers=markers,
        output=output_dataset,
        extension_distance=args.extension_distance,
        sampling=resolution,
        block_shape=block_shape,
        n_threads=16,
    )


# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
Loading