Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- pytorch
- s3fs
- torch_em
- trimesh
- z5py
# Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet
- zarr <3
146 changes: 146 additions & 0 deletions flamingo_tools/measurements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import multiprocessing as mp
from concurrent import futures
from typing import Optional

import numpy as np
import pandas as pd
import trimesh
from skimage.measure import marching_cubes
from tqdm import tqdm

from .file_utils import read_image_data
from .segmentation.postprocessing import compute_table_on_the_fly


def _measure_volume_and_surface(mask, resolution):
    """Measure volume and surface area of a 3D object mask via a surface mesh.

    The mask is meshed with marching cubes, using `resolution` as the voxel
    spacing along all three axes. The measurements are taken from the
    resulting triangle mesh.

    Args:
        mask: The 3D mask of the object.
        resolution: The voxel size, applied isotropically as mesh spacing.

    Returns:
        The absolute mesh volume, or NaN if the mesh is not watertight.
        The surface area of the mesh.
    """
    spacing = (resolution, resolution, resolution)
    vertices, triangles, vertex_normals, _ = marching_cubes(mask, spacing=spacing)
    mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_normals=vertex_normals)

    surface = mesh.area
    # The volume is only reported for a closed mesh; otherwise it is marked as missing.
    volume = np.abs(mesh.volume) if mesh.is_watertight else np.nan
    return volume, surface


def compute_object_measures_impl(
    image: np.typing.ArrayLike,
    segmentation: np.typing.ArrayLike,
    n_threads: Optional[int] = None,
    resolution: float = 0.38,
    table: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
    """Compute simple intensity and morphology measures for each segmented cell in a segmentation.

    For every row of the segmentation table the object is cut out via its bounding box
    and the mean, standard deviation, minimum, maximum, median and the
    5th, 10th, 25th, 75th, 90th and 95th percentile of the image intensities inside
    the object are computed, as well as the volume and surface of the object.
    See `compute_object_measures` for details.

    Args:
        image: The image data.
        segmentation: The segmentation.
        n_threads: The number of threads to use for computation.
            Uses the number of available CPU cores if None.
        resolution: The resolution / voxel size of the data.
        table: The segmentation table. Will be computed on the fly if it is not given.
            It must provide the columns `label_id`, `bb_min_{z,y,x}` and `bb_max_{z,y,x}`,
            with the bounding boxes given in physical coordinates.

    Returns:
        The table with per object measurements.
    """
    if table is None:
        table = compute_table_on_the_fly(segmentation, resolution=resolution)

    seg_ids = table.label_id.values
    assert len(seg_ids) > 0, "The segmentation table is empty."

    def to_voxel_coordinates(columns):
        # The table stores the bounding boxes in physical units, so divide by the voxel size.
        coordinates = table[columns].values.astype("float32") / resolution
        return np.round(coordinates, 0).astype("int32")

    # Convert all bounding boxes once, instead of filtering the full table for every object,
    # which would scale quadratically with the number of objects.
    bb_starts = to_voxel_coordinates(["bb_min_z", "bb_min_y", "bb_min_x"])
    bb_stops = to_voxel_coordinates(["bb_max_z", "bb_max_y", "bb_max_x"])

    def intensity_measures(index):
        seg_id = seg_ids[index]

        # Enlarge the bounding box by one voxel per side and clip it to the image shape.
        bb = tuple(
            slice(max(bmin - 1, 0), min(bmax + 1, sh))
            for bmin, bmax, sh in zip(bb_starts[index], bb_stops[index], image.shape)
        )

        local_image = image[bb]
        mask = segmentation[bb] == seg_id
        assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
        masked_intensity = local_image[mask]

        # Do the base intensity measurements.
        measures = {
            "label_id": seg_id,
            "mean": np.mean(masked_intensity),
            "stdev": np.std(masked_intensity),
            "min": np.min(masked_intensity),
            "max": np.max(masked_intensity),
            "median": np.median(masked_intensity),
        }
        for percentile in (5, 10, 25, 75, 90, 95):
            measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)

        # Do the volume and surface measurement.
        volume, surface = _measure_volume_and_surface(mask, resolution)
        measures["volume"] = volume
        measures["surface"] = surface
        return measures

    n_threads = mp.cpu_count() if n_threads is None else n_threads
    with futures.ThreadPoolExecutor(n_threads) as pool:
        measures = list(tqdm(
            pool.map(intensity_measures, range(len(seg_ids))), total=len(seg_ids), desc="Compute intensity measures"
        ))

    # Create the result table with one row per object and one column per measure.
    keys = measures[0].keys()
    measures = pd.DataFrame({k: [measure[k] for measure in measures] for k in keys})
    return measures


# Could also support s3 directly?
def compute_object_measures(
    image_path: str,
    segmentation_path: str,
    segmentation_table_path: str,
    output_table_path: str,
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    n_threads: Optional[int] = None,
    resolution: float = 0.38,
) -> None:
    """Compute simple intensity and morphology measures for each segmented cell and save them as a table.

    The measures are the mean, standard deviation, minimum, maximum, median and the
    5th, 10th, 25th, 75th, 90th and 95th percentile of the image intensity per cell,
    as well as the volume and surface of each cell.
    The result is written as a tab-separated table to `output_table_path`.

    Args:
        image_path: The filepath to the image data. Either a tif or hdf5/zarr/n5 file.
        segmentation_path: The filepath to the segmentation data. Either a tif or hdf5/zarr/n5 file.
        segmentation_table_path: The path to the segmentation table in MoBIE format.
        output_table_path: The path for saving the table with the per object measures.
        image_key: The key (= internal path) for the image data. Not needed for tif.
        segmentation_key: The key (= internal path) for the segmentation data. Not needed for tif.
        n_threads: The number of threads to use for computation.
        resolution: The resolution / voxel size of the data.
    """
    # The pre-computed MoBIE segmentation table is read first, it is stored tab-separated.
    segmentation_table = pd.read_csv(segmentation_table_path, sep="\t")

    # Open the image and the segmentation volume.
    intensity_image = read_image_data(image_path, image_key)
    label_volume = read_image_data(segmentation_path, segmentation_key)

    # Run the measurements and write them out in the same tab-separated format.
    result = compute_object_measures_impl(
        image=intensity_image,
        segmentation=label_volume,
        n_threads=n_threads,
        resolution=resolution,
        table=segmentation_table,
    )
    result.to_csv(output_table_path, sep="\t", index=False)
35 changes: 27 additions & 8 deletions flamingo_tools/segmentation/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,39 @@ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray:
#


def compute_table_on_the_fly(segmentation: np.typing.ArrayLike, resolution: float) -> pd.DataFrame:
    """Compute a segmentation table compatible with MoBIE.

    The table contains information about the number of pixels per object,
    the anchor (= centroid) and the bounding box. Anchor and bounding box are given in physical coordinates.

    The bounding box is split into its first and last three entries,
    so a 3D segmentation with axis order (z, y, x) is expected.

    Args:
        segmentation: The segmentation for which to compute the table.
        resolution: The physical voxel spacing of the data.

    Returns:
        The segmentation table.
    """
    props = measure.regionprops(segmentation)
    label_ids = np.array([prop.label for prop in props])

    # Transform centroid and bounding box from pixel to physical units.
    coordinates = np.array([prop.centroid for prop in props]).astype("float32")
    coordinates = coordinates * resolution
    bb_min = np.array([prop.bbox[:3] for prop in props]).astype("float32") * resolution
    bb_max = np.array([prop.bbox[3:] for prop in props]).astype("float32") * resolution

    sizes = np.array([prop.area for prop in props])
    table = pd.DataFrame({
        "label_id": label_ids,
        "anchor_x": coordinates[:, 2],
        "anchor_y": coordinates[:, 1],
        "anchor_z": coordinates[:, 0],
        "bb_min_x": bb_min[:, 2],
        "bb_min_y": bb_min[:, 1],
        "bb_min_z": bb_min[:, 0],
        "bb_max_x": bb_max[:, 2],
        "bb_max_y": bb_max[:, 1],
        "bb_max_z": bb_max[:, 0],
        "n_pixels": sizes,
    })
    return table


# Keep the previous private name available for code that still refers to it.
_compute_table = compute_table_on_the_fly

Expand Down Expand Up @@ -160,13 +180,12 @@ def filter_segmentation(
spatial_statistics_kwargs: Arguments for spatial statistics function

Returns:
n_ids
n_ids_filtered
The number of objects before filtering.
The number of objects after filtering.
"""
# Compute the table on the fly.
# NOTE: this currently doesn't work for large segmentations.
# Compute the table on the fly. This doesn't work for large segmentations.
if table is None:
table = _compute_table(segmentation, resolution=resolution)
table = compute_table_on_the_fly(segmentation, resolution=resolution)
n_ids = len(table)

# First apply the size filter.
Expand Down
5 changes: 4 additions & 1 deletion flamingo_tools/segmentation/unet_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def run_unet_prediction(
scale: Optional[float] = None,
block_shape: Optional[Tuple[int, int, int]] = None,
halo: Optional[Tuple[int, int, int]] = None,
use_mask: bool = True,
) -> None:
"""Run prediction and segmentation with a distance U-Net.

Expand All @@ -337,10 +338,12 @@ def run_unet_prediction(
By default the data will not be rescaled.
block_shape: The block-shape for running the prediction.
halo: The halo (= block overlap) to use for prediction.
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
"""
os.makedirs(output_folder, exist_ok=True)

find_mask(input_path, input_key, output_folder)
if use_mask:
find_mask(input_path, input_key, output_folder)

original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo
Expand Down
70 changes: 69 additions & 1 deletion flamingo_tools/test_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,75 @@
import os
from typing import Tuple

import imageio.v3 as imageio
from skimage.data import binary_blobs
import requests
from skimage.data import binary_blobs, cells3d
from skimage.measure import label

from .segmentation.postprocessing import compute_table_on_the_fly

SEGMENTATION_URL = "https://owncloud.gwdg.de/index.php/s/kwoGRYiJRRrswgw/download"


def get_test_volume_and_segmentation(folder: str) -> Tuple[str, str, str]:
    """Download a small volume with nuclei and corresponding segmentation.

    The segmentation is downloaded from `SEGMENTATION_URL`. The image is taken from
    the `cells3d` sample data of skimage (planes 20 to 39 of channel 1).
    The segmentation table is computed from the segmentation with a resolution of 0.38.

    Args:
        folder: The test data folder. The data will be downloaded to this folder.

    Returns:
        The path to the image, stored as tif.
        The path to the segmentation, stored as tif.
        The path to the segmentation table, stored as tsv.

    Raises:
        requests.RequestException: If the download fails or does not respond in time.
    """
    os.makedirs(folder, exist_ok=True)

    segmentation_path = os.path.join(folder, "segmentation.tif")
    # Without a timeout the request can block forever if the server stops responding.
    resp = requests.get(SEGMENTATION_URL, timeout=60)
    resp.raise_for_status()

    with open(segmentation_path, "wb") as f:
        f.write(resp.content)

    nuclei = cells3d()[20:40, 1]
    segmentation = imageio.imread(segmentation_path)
    # The downloaded segmentation must match the image crop voxel by voxel.
    assert nuclei.shape == segmentation.shape

    image_path = os.path.join(folder, "image.tif")
    imageio.imwrite(image_path, nuclei)

    table_path = os.path.join(folder, "default.tsv")
    table = compute_table_on_the_fly(segmentation, resolution=0.38)
    table.to_csv(table_path, sep="\t", index=False)

    return image_path, segmentation_path, table_path


def create_image_data_and_segmentation(folder: str, size: int = 256) -> Tuple[str, str, str]:
    """Create synthetic test data: an image, its segmentation and the segmentation table.

    The image is a 3D volume of binary blobs scaled to the values 0 and 255,
    the segmentation is obtained by labeling this volume. The segmentation
    table is computed with a resolution of 0.38.

    Args:
        folder: The test data folder. The data will be written to this folder.
        size: The size argument passed to `binary_blobs` for generating the volume.

    Returns:
        The path to the image, stored as tif.
        The path to the segmentation, stored as tif.
        The path to the segmentation table, stored as tsv.
    """
    os.makedirs(folder, exist_ok=True)

    image_path = os.path.join(folder, "image.tif")
    segmentation_path = os.path.join(folder, "segmentation.tif")
    table_path = os.path.join(folder, "default.tsv")

    # Generate the blob image and derive the segmentation from it.
    blobs = binary_blobs(size, n_dim=3).astype("uint8") * 255
    labels = label(blobs)

    imageio.imwrite(image_path, blobs)
    imageio.imwrite(segmentation_path, labels)

    # The table is stored tab-separated.
    compute_table_on_the_fly(labels, resolution=0.38).to_csv(table_path, sep="\t", index=False)

    return image_path, segmentation_path, table_path


# TODO add metadata
Expand Down
48 changes: 48 additions & 0 deletions scripts/measurements/measure_sgns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
import os

import numpy as np
import pandas as pd
from flamingo_tools.s3_utils import create_s3_target, BUCKET_NAME


def open_json(fs, path):
    """Load a JSON file from the bucket.

    Args:
        fs: The file system object used to open the file.
        path: The path of the JSON file, relative to `BUCKET_NAME`.

    Returns:
        The parsed JSON content.
    """
    with fs.open(os.path.join(BUCKET_NAME, path), "r") as handle:
        return json.load(handle)


def open_tsv(fs, path):
    """Load a tab-separated table from the bucket.

    Args:
        fs: The file system object used to open the file.
        path: The path of the table, relative to `BUCKET_NAME`.

    Returns:
        The table as a pandas DataFrame.
    """
    with fs.open(os.path.join(BUCKET_NAME, path), "r") as handle:
        return pd.read_csv(handle, sep="\t")


def main():
    """Print the number of SGNs for all SGN segmentation sources of the project.

    Reads the project and dataset metadata from the bucket. For every source whose name
    starts with "SGN" it loads the default segmentation table and reports the number of
    objects with a non-zero component label, in total and for the largest component.
    """
    fs = create_s3_target()
    project_info = open_json(fs, "project.json")
    for dataset in project_info["datasets"]:
        # The "fens" dataset is excluded from the measurements.
        if dataset == "fens":
            continue
        print(dataset)
        dataset_info = open_json(fs, os.path.join(dataset, "dataset.json"))
        sources = dataset_info["sources"]
        for source, source_info in sources.items():
            if not source.startswith("SGN"):
                continue
            assert "segmentation" in source_info
            source_info = source_info["segmentation"]
            table_path = source_info["tableData"]["tsv"]["relativePath"]
            table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
            component_labels = table.component_labels.values
            # Only objects with a non-zero component label are counted.
            remaining_sgns = component_labels[component_labels != 0]
            print(source)
            print("Number of SGNs (all components) :", len(remaining_sgns))
            _, n_per_component = np.unique(remaining_sgns, return_counts=True)
            # A table without any non-zero component label has no components;
            # report 0 instead of failing on the maximum of an empty sequence.
            print("Number of SGNs (largest component):", max(n_per_component, default=0))


if __name__ == "__main__":
main()
35 changes: 35 additions & 0 deletions scripts/sgn_stain_predictions/check_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
from glob import glob

import imageio.v3 as imageio
import napari


ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations"


def main():
    """Show each image below ROOT in napari, together with its segmentation if one exists.

    The segmentation for an image is looked up under SAVE_ROOT, at the same path
    relative to ROOT as the image. One viewer is opened per image and the next
    image is shown after the viewer has been closed.
    """
    # NOTE(review): without recursive=True the '**' in the pattern behaves like '*',
    # so only tif files exactly one directory below ROOT are matched - confirm this is intended.
    files = sorted(glob(os.path.join(ROOT, "**/*.tif")))
    for ff in files:
        # Skip the segmentation results themselves, but keep going through the remaining images.
        if "segmentations" in ff:
            continue
        print("Visualizing", ff)
        rel_path = os.path.relpath(ff, ROOT)
        seg_path = os.path.join(SAVE_ROOT, rel_path)

        image = imageio.imread(ff)
        seg = imageio.imread(seg_path) if os.path.exists(seg_path) else None

        v = napari.Viewer()
        v.add_image(image)
        if seg is not None:
            v.add_labels(seg)
        napari.run()


if __name__ == "__main__":
main()
Loading