137 changes: 125 additions & 12 deletions flamingo_tools/measurements.py
@@ -2,12 +2,16 @@
 import os
 from concurrent import futures
 from functools import partial
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import numpy as np
 import pandas as pd
 import trimesh
+from elf.io import open_file
+from elf.wrapper.resized_volume import ResizedVolume
+from nifty.tools import blocking
 from skimage.measure import marching_cubes, regionprops_table
+from scipy.ndimage import binary_dilation
 from tqdm import tqdm

 from .file_utils import read_image_data
@@ -29,9 +33,14 @@ def _measure_volume_and_surface(mask, resolution):
     return volume, surface


-def _get_bounding_box_and_center(table, seg_id, resolution, shape):
+def _get_bounding_box_and_center(table, seg_id, resolution, shape, dilation):
     row = table[table.label_id == seg_id]

+    if dilation is not None and dilation > 0:
+        bb_extension = dilation + 1
+    else:
+        bb_extension = 1
+
     bb_min = np.array([
         row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item()
     ]).astype("float32") / resolution
@@ -43,7 +52,7 @@ def _get_bounding_box_and_center(table, seg_id, resolution, shape):
     bb_max = np.round(bb_max, 0).astype("int32")

     bb = tuple(
-        slice(max(bmin - 1, 0), min(bmax + 1, sh))
+        slice(max(bmin - bb_extension, 0), min(bmax + bb_extension, sh))
         for bmin, bmax, sh in zip(bb_min, bb_max, shape)
     )

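Editor's sketch (not part of the diff): the bounding box is padded by `dilation + 1` because `scipy.ndimage.binary_dilation` with `iterations=n` grows a mask by up to n voxels in each direction, so the local window needs at least that much margin for the dilated mask to still fit. A minimal illustration:

import numpy as np
from scipy.ndimage import binary_dilation

# One foreground voxel in the center of a small volume.
mask = np.zeros((7, 7, 7), dtype=bool)
mask[3, 3, 3] = True

# With iterations=2 the mask grows by two voxels per direction (a radius-2
# octahedron for the default cross-shaped structuring element), so a bounding
# box padded by dilation + 1 = 3 still contains the dilated object.
dilated = binary_dilation(mask, iterations=2)
print(mask.sum(), dilated.sum())  # 1 25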
@@ -115,13 +124,15 @@ def _normalize_background(measures, image, mask, center, radius, norm, median_only):

 def _default_object_features(
     seg_id, table, image, segmentation, resolution,
-    foreground_mask=None, background_radius=None, norm=np.divide, median_only=False,
+    background_mask=None, background_radius=None, norm=np.divide, median_only=False, dilation=None
 ):
-    bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)
+    bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape, dilation)

     local_image = image[bb]
     mask = segmentation[bb] == seg_id
     assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
+    if dilation is not None and dilation > 0:
+        mask = binary_dilation(mask, iterations=dilation)
     masked_intensity = local_image[mask]

     # Do the base intensity measurements.
@@ -141,7 +152,7 @@ def _default_object_features(
         # The resolution is given in micrometer per pixel.
         # So we have to divide by the resolution to obtain the radius in pixel.
         radius_in_pixel = background_radius / resolution
-        measures = _normalize_background(measures, image, foreground_mask, center, radius_in_pixel, norm, median_only)
+        measures = _normalize_background(measures, image, background_mask, center, radius_in_pixel, norm, median_only)

     # Do the volume and surface measurement.
     if not median_only:
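For intuition, the unit conversion above in isolation (the values are hypothetical):

# resolution is given in micrometer per pixel, background_radius in micrometer.
background_radius = 19.0   # micrometer
resolution = 0.38          # micrometer per pixel
radius_in_pixel = background_radius / resolution  # 19.0 / 0.38 = 50.0 pixels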
@@ -151,13 +162,15 @@ def _default_object_features(
     return measures


-def _regionprops_features(seg_id, table, image, segmentation, resolution, foreground_mask=None):
-    bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)
+def _regionprops_features(seg_id, table, image, segmentation, resolution, background_mask=None, dilation=None):
+    bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape, dilation)

     local_image = image[bb]
     local_segmentation = segmentation[bb]
     mask = local_segmentation == seg_id
     assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
+    if dilation is not None and dilation > 0:
+        mask = binary_dilation(mask, iterations=dilation)
     local_segmentation[~mask] = 0

     features = regionprops_table(
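For readers unfamiliar with `skimage.measure.regionprops_table`: it returns a dict of per-object feature arrays for a label image, optionally paired with an intensity image. A self-contained sketch (the property list is illustrative, not the one used by this PR):

import numpy as np
from skimage.measure import regionprops_table

labels = np.zeros((16, 16, 16), dtype="uint32")
labels[4:10, 4:10, 4:10] = 1
intensity = np.random.rand(16, 16, 16).astype("float32")

features = regionprops_table(
    labels, intensity_image=intensity,
    properties=["label", "area", "mean_intensity"],
)
print(features["area"])  # [216.] -> 6 * 6 * 6 voxels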
@@ -196,16 +209,16 @@ def _regionprops_features(seg_id, table, image, segmentation, resolution, foreground_mask=None):
 """


-# TODO integrate segmentation post-processing, see `_extend_sgns_simple` in `gfp_annotation.py`
 def compute_object_measures_impl(
     image: np.typing.ArrayLike,
     segmentation: np.typing.ArrayLike,
     n_threads: Optional[int] = None,
     resolution: float = 0.38,
     table: Optional[pd.DataFrame] = None,
     feature_set: str = "default",
-    foreground_mask: Optional[np.typing.ArrayLike] = None,
+    background_mask: Optional[np.typing.ArrayLike] = None,
     median_only: bool = False,
+    dilation: Optional[int] = None,
 ) -> pd.DataFrame:
     """Compute simple intensity and morphology measures for each segmented cell in a segmentation.

@@ -218,8 +231,10 @@ def compute_object_measures_impl(
         resolution: The resolution / voxel size of the data.
         table: The segmentation table. Will be computed on the fly if it is not given.
         feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
-        foreground_mask: An optional mask indicating the area to use for computing background correction values.
+        background_mask: An optional mask indicating the area to use for computing background correction values.
         median_only: Whether to only compute the median intensity.
+        dilation: The number of iterations for dilating each object mask before computing measurements.
+            By default no dilation is applied.

     Returns:
         The table with per object measurements.
@@ -235,8 +250,9 @@ def compute_object_measures_impl(
         image=image,
         segmentation=segmentation,
         resolution=resolution,
-        foreground_mask=foreground_mask,
+        background_mask=background_mask,
         median_only=median_only,
+        dilation=dilation,
     )

     seg_ids = table.label_id.values
@@ -246,6 +262,7 @@ def compute_object_measures_impl(

     # For debugging.
     # measure_function(seg_ids[0])
+    # breakpoint()

     with futures.ThreadPoolExecutor(n_threads) as pool:
         measures = list(tqdm(
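A minimal usage sketch for the extended `compute_object_measures_impl` on synthetic data (the array contents and parameter values are made up; the segmentation table is computed on the fly since none is passed):

import numpy as np
from flamingo_tools.measurements import compute_object_measures_impl

# Synthetic volume with one bright object on a dim background.
image = np.random.rand(32, 64, 64).astype("float32")
segmentation = np.zeros((32, 64, 64), dtype="uint32")
segmentation[10:20, 20:40, 20:40] = 1
image[segmentation == 1] += 2.0

# Dilate each object mask by 2 voxels before measuring.
measures = compute_object_measures_impl(
    image, segmentation, n_threads=1, resolution=0.38,
    feature_set="default", dilation=2,
)
print(measures.head())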
@@ -272,6 +289,9 @@ def compute_object_measures(
     feature_set: str = "default",
     s3_flag: bool = False,
     component_list: List[int] = [],
+    dilation: Optional[int] = None,
+    median_only: bool = False,
+    background_mask: Optional[np.typing.ArrayLike] = None,
 ) -> None:
     """Compute simple intensity and morphology measures for each segmented cell in a segmentation.

@@ -291,6 +311,12 @@ def compute_object_measures(
         resolution: The resolution / voxel size of the data.
         force: Whether to overwrite an existing output table.
         feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
+        s3_flag: Whether the input data is stored on S3.
+        component_list: The list of component ids to restrict the measurements to.
+        dilation: The number of iterations for dilating each object mask before computing measurements.
+            By default no dilation is applied.
+        median_only: Whether to only compute the median intensity.
+        background_mask: An optional mask indicating the area to use for computing background correction values.
     """
     if os.path.exists(output_table_path) and not force:
         return
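A hedged end-to-end sketch for the wrapper (all paths, keys, and the names of the path keyword arguments are placeholders inferred from the surrounding code, not confirmed by this diff):

from flamingo_tools.measurements import compute_object_measures

compute_object_measures(
    image_path="pv_channel.n5",               # hypothetical path
    segmentation_path="sgn_segmentation.n5",  # hypothetical path
    segmentation_table_path="sgn_table.tsv",  # hypothetical path
    output_table_path="sgn_measures.tsv",
    image_key="setup0/timepoint0/s0",         # hypothetical zarr/n5 key
    segmentation_key="setup0/timepoint0/s0",  # hypothetical zarr/n5 key
    dilation=2, median_only=True,
)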
@@ -315,5 +341,92 @@ def compute_object_measures(

     measures = compute_object_measures_impl(
         image, segmentation, n_threads, resolution, table=table, feature_set=feature_set,
+        median_only=median_only, dilation=dilation, background_mask=background_mask,
     )
     measures.to_csv(output_table_path, sep="\t", index=False)
+
+
+def compute_sgn_background_mask(
+    image_path: str,
+    segmentation_path: str,
+    image_key: Optional[str] = None,
+    segmentation_key: Optional[str] = None,
+    threshold_percentile: float = 35.0,
+    scale_factor: Tuple[int, int, int] = (16, 16, 16),
+    n_threads: Optional[int] = None,
+    cache_path: Optional[str] = None,
+) -> np.typing.ArrayLike:
+    """Compute the background mask for intensity measurements in the SGN segmentation.
+
+    The mask is computed by downsampling the image (PV channel) and the segmentation (SGNs) internally,
+    thresholding the downsampled image, and then excluding the (dilated) segmentation from the
+    thresholded mask. This results in a mask that is positive for the background signal
+    within Rosenthal's canal.
+
+    Args:
+        image_path: The path to the image data with the PV channel.
+        segmentation_path: The path to the SGN segmentation.
+        image_key: Internal path for the image data, for zarr or similar file formats.
+        segmentation_key: Internal path for the segmentation data, for zarr or similar file formats.
+        threshold_percentile: The percentile threshold for separating foreground and background in the PV signal.
+        scale_factor: The scale factor for internally downsampling the image and segmentation.
+        n_threads: The number of threads for parallelizing the computation.
+        cache_path: Optional path for caching the downscaled background mask as zarr.
+
+    Returns:
+        The mask for determining the background values.
+    """
+    image = read_image_data(image_path, image_key)
+    segmentation = read_image_data(segmentation_path, segmentation_key)
+    assert image.shape == segmentation.shape
+
+    if cache_path is not None and os.path.exists(cache_path):
+        with open_file(cache_path, "r") as f:
+            if "mask" in f:
+                low_res_mask = f["mask"][:]
+                mask = ResizedVolume(low_res_mask, shape=image.shape, order=0)
+                return mask
+
+    original_shape = image.shape
+    downsampled_shape = tuple(int(np.round(sh / sf)) for sh, sf in zip(original_shape, scale_factor))
+
+    low_res_mask = np.zeros(downsampled_shape, dtype="bool")
+
+    # This corresponds to a block shape of 128 x 512 x 512 in the original resolution,
+    # which roughly corresponds to the size of the blocks we use for the GFP annotation.
+    chunk_shape = (8, 32, 32)
+
+    blocks = blocking((0, 0, 0), downsampled_shape, chunk_shape)
+    n_blocks = blocks.numberOfBlocks
+
+    img_resized = ResizedVolume(image, downsampled_shape)
+    seg_resized = ResizedVolume(segmentation, downsampled_shape, order=0)
+
+    def _compute_block(block_id):
+        block = blocks.getBlock(block_id)
+        bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
+
+        img = img_resized[bb]
+        threshold = np.percentile(img, threshold_percentile)
+
+        this_mask = img > threshold
+        this_seg = seg_resized[bb] != 0
+        this_seg = binary_dilation(this_seg)
+        this_mask[this_seg] = 0
+
+        low_res_mask[bb] = this_mask
+
+    n_threads = mp.cpu_count() if n_threads is None else n_threads
+    randomized_blocks = np.arange(0, n_blocks)
+    np.random.shuffle(randomized_blocks)
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        list(tqdm(
+            tp.map(_compute_block, randomized_blocks), total=n_blocks, desc="Compute background mask"
+        ))
+
+    if cache_path is not None:
+        with open_file(cache_path, "a") as f:
+            f.create_dataset("mask", data=low_res_mask, chunks=(64, 64, 64))
+
+    mask = ResizedVolume(low_res_mask, shape=original_shape, order=0)
+    return mask
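A usage sketch for the new helper, chaining the mask into the measurement API (paths and keys are placeholders; the background mask is only consulted by feature sets that perform background normalization):

from flamingo_tools.file_utils import read_image_data
from flamingo_tools.measurements import (
    compute_object_measures_impl, compute_sgn_background_mask,
)

# Hypothetical paths; the keys address datasets inside zarr/n5 containers.
mask = compute_sgn_background_mask(
    "pv_channel.n5", "sgn_segmentation.n5",
    image_key="setup0/timepoint0/s0", segmentation_key="setup0/timepoint0/s0",
    cache_path="background_mask.zarr",
)

image = read_image_data("pv_channel.n5", "setup0/timepoint0/s0")
segmentation = read_image_data("sgn_segmentation.n5", "setup0/timepoint0/s0")
measures = compute_object_measures_impl(
    image, segmentation, background_mask=mask, median_only=True, dilation=2,
)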