Skip to content

Commit aee6121

Browse files
Add script for measuring the number of SGNs (#27)
Add intensity measurement functionality and scripts for segmentation and measurement of new SGN stainings
1 parent 68fba2f commit aee6121

File tree

11 files changed

+517
-11
lines changed

11 files changed

+517
-11
lines changed

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- pytorch
1313
- s3fs
1414
- torch_em
15+
- trimesh
1516
- z5py
1617
# Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet
1718
- zarr <3

flamingo_tools/measurements.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import multiprocessing as mp
2+
from concurrent import futures
3+
from typing import Optional
4+
5+
import numpy as np
6+
import pandas as pd
7+
import trimesh
8+
from skimage.measure import marching_cubes
9+
from tqdm import tqdm
10+
11+
from .file_utils import read_image_data
12+
from .segmentation.postprocessing import compute_table_on_the_fly
13+
14+
15+
def _measure_volume_and_surface(mask, resolution):
16+
# Use marching_cubes for 3D data
17+
verts, faces, normals, _ = marching_cubes(mask, spacing=(resolution,) * 3)
18+
19+
mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
20+
surface = mesh.area
21+
if mesh.is_watertight:
22+
volume = np.abs(mesh.volume)
23+
else:
24+
volume = np.nan
25+
26+
return volume, surface
27+
28+
29+
def compute_object_measures_impl(
30+
image: np.typing.ArrayLike,
31+
segmentation: np.typing.ArrayLike,
32+
n_threads: Optional[int] = None,
33+
resolution: float = 0.38,
34+
table: Optional[pd.DataFrame] = None,
35+
) -> pd.DataFrame:
36+
"""Compute simple intensity and morphology measures for each segmented cell in a segmentation.
37+
38+
See `compute_object_measures` for details.
39+
40+
Args:
41+
image: The image data.
42+
segmentation: The segmentation.
43+
n_threads: The number of threads to use for computation.
44+
resolution: The resolution / voxel size of the data.
45+
table: The segmentation table. Will be computed on the fly if it is not given.
46+
47+
Returns:
48+
The table with per object measurements.
49+
"""
50+
if table is None:
51+
table = compute_table_on_the_fly(segmentation, resolution=resolution)
52+
53+
def intensity_measures(seg_id):
54+
# Get the bounding box.
55+
row = table[table.label_id == seg_id]
56+
57+
bb_min = np.array([
58+
row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item()
59+
]).astype("float32") / resolution
60+
bb_min = np.round(bb_min, 0).astype("int32")
61+
62+
bb_max = np.array([
63+
row.bb_max_z.item(), row.bb_max_y.item(), row.bb_max_x.item()
64+
]).astype("float32") / resolution
65+
bb_max = np.round(bb_max, 0).astype("int32")
66+
67+
bb = tuple(
68+
slice(max(bmin - 1, 0), min(bmax + 1, sh))
69+
for bmin, bmax, sh in zip(bb_min, bb_max, image.shape)
70+
)
71+
72+
local_image = image[bb]
73+
mask = segmentation[bb] == seg_id
74+
assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
75+
masked_intensity = local_image[mask]
76+
77+
# Do the base intensity measurements.
78+
measures = {
79+
"label_id": seg_id,
80+
"mean": np.mean(masked_intensity),
81+
"stdev": np.std(masked_intensity),
82+
"min": np.min(masked_intensity),
83+
"max": np.max(masked_intensity),
84+
"median": np.median(masked_intensity),
85+
}
86+
for percentile in (5, 10, 25, 75, 90, 95):
87+
measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)
88+
89+
# Do the volume and surface measurement.
90+
volume, surface = _measure_volume_and_surface(mask, resolution)
91+
measures["volume"] = volume
92+
measures["surface"] = surface
93+
return measures
94+
95+
seg_ids = table.label_id.values
96+
assert len(seg_ids) > 0, "The segmentation table is empty."
97+
n_threads = mp.cpu_count() if n_threads is None else n_threads
98+
with futures.ThreadPoolExecutor(n_threads) as pool:
99+
measures = list(tqdm(
100+
pool.map(intensity_measures, seg_ids), total=len(seg_ids), desc="Compute intensity measures"
101+
))
102+
103+
# Create the result table and save it.
104+
keys = measures[0].keys()
105+
measures = pd.DataFrame({k: [measure[k] for measure in measures] for k in keys})
106+
return measures
107+
108+
109+
# Could also support s3 directly?
110+
def compute_object_measures(
111+
image_path: str,
112+
segmentation_path: str,
113+
segmentation_table_path: str,
114+
output_table_path: str,
115+
image_key: Optional[str] = None,
116+
segmentation_key: Optional[str] = None,
117+
n_threads: Optional[int] = None,
118+
resolution: float = 0.38,
119+
) -> None:
120+
"""Compute simple intensity and morphology measures for each segmented cell in a segmentation.
121+
122+
This computes the mean, standard deviation, minimum, maximum, median and
123+
5th, 10th, 25th, 75th, 90th and 95th percentile of the intensity image
124+
per cell, as well as the volume and surface.
125+
126+
Args:
127+
image_path: The filepath to the image data. Either a tif or hdf5/zarr/n5 file.
128+
segmentation_path: The filepath to the segmentation data. Either a tif or hdf5/zarr/n5 file.
129+
segmentation_table_path: The path to the segmentation table in MoBIE format.
130+
output_table_path: The path for saving the segmentation with intensity measures.
131+
image_key: The key (= internal path) for the image data. Not needed fir tif.
132+
segmentation_key: The key (= internal path) for the segmentation data. Not needed for tif.
133+
n_threads: The number of threads to use for computation.
134+
resolution: The resolution / voxel size of the data.
135+
"""
136+
# First, we load the pre-computed segmentation table from MoBIE.
137+
table = pd.read_csv(segmentation_table_path, sep="\t")
138+
139+
# Then, open the volumes.
140+
image = read_image_data(image_path, image_key)
141+
segmentation = read_image_data(segmentation_path, segmentation_key)
142+
143+
measures = compute_object_measures_impl(
144+
image, segmentation, n_threads, resolution, table=table
145+
)
146+
measures.to_csv(output_table_path, sep="\t", index=False)

flamingo_tools/segmentation/postprocessing.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,39 @@ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray:
115115
#
116116

117117

118-
def _compute_table(segmentation, resolution):
118+
def compute_table_on_the_fly(segmentation: np.typing.ArrayLike, resolution: float) -> pd.DataFrame:
119+
"""Compute a segmentation table compatible with MoBIE.
120+
121+
The table contains information about the number of pixels per object,
122+
the anchor (= centroid) and the bounding box. Anchor and bounding box are given in physical coordinates.
123+
124+
Args:
125+
segmentation: The segmentation for which to compute the table.
126+
resolution: The physical voxel spacing of the data.
127+
128+
Returns:
129+
The segmentation table.
130+
"""
119131
props = measure.regionprops(segmentation)
120132
label_ids = np.array([prop.label for prop in props])
121-
coordinates = np.array([prop.centroid for prop in props])
133+
coordinates = np.array([prop.centroid for prop in props]).astype("float32")
122134
# transform pixel distance to physical units
123135
coordinates = coordinates * resolution
136+
bb_min = np.array([prop.bbox[:3] for prop in props]).astype("float32") * resolution
137+
bb_max = np.array([prop.bbox[3:] for prop in props]).astype("float32") * resolution
124138
sizes = np.array([prop.area for prop in props])
125139
table = pd.DataFrame({
126140
"label_id": label_ids,
127-
"n_pixels": sizes,
128141
"anchor_x": coordinates[:, 2],
129142
"anchor_y": coordinates[:, 1],
130143
"anchor_z": coordinates[:, 0],
144+
"bb_min_x": bb_min[:, 2],
145+
"bb_min_y": bb_min[:, 1],
146+
"bb_min_z": bb_min[:, 0],
147+
"bb_max_x": bb_max[:, 2],
148+
"bb_max_y": bb_max[:, 1],
149+
"bb_max_z": bb_max[:, 0],
150+
"n_pixels": sizes,
131151
})
132152
return table
133153

@@ -160,13 +180,12 @@ def filter_segmentation(
160180
spatial_statistics_kwargs: Arguments for spatial statistics function
161181
162182
Returns:
163-
n_ids
164-
n_ids_filtered
183+
The number of objects before filtering.
184+
The number of objects after filtering.
165185
"""
166-
# Compute the table on the fly.
167-
# NOTE: this currently doesn't work for large segmentations.
186+
# Compute the table on the fly. This doesn't work for large segmentations.
168187
if table is None:
169-
table = _compute_table(segmentation, resolution=resolution)
188+
table = compute_table_on_the_fly(segmentation, resolution=resolution)
170189
n_ids = len(table)
171190

172191
# First apply the size filter.

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def run_unet_prediction(
324324
scale: Optional[float] = None,
325325
block_shape: Optional[Tuple[int, int, int]] = None,
326326
halo: Optional[Tuple[int, int, int]] = None,
327+
use_mask: bool = True,
327328
) -> None:
328329
"""Run prediction and segmentation with a distance U-Net.
329330
@@ -337,10 +338,12 @@ def run_unet_prediction(
337338
By default the data will not be rescaled.
338339
block_shape: The block-shape for running the prediction.
339340
halo: The halo (= block overlap) to use for prediction.
341+
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
340342
"""
341343
os.makedirs(output_folder, exist_ok=True)
342344

343-
find_mask(input_path, input_key, output_folder)
345+
if use_mask:
346+
find_mask(input_path, input_key, output_folder)
344347

345348
original_shape = prediction_impl(
346349
input_path, input_key, output_folder, model_path, scale, block_shape, halo

flamingo_tools/test_data.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,75 @@
11
import os
2+
from typing import Tuple
23

34
import imageio.v3 as imageio
4-
from skimage.data import binary_blobs
5+
import requests
6+
from skimage.data import binary_blobs, cells3d
7+
from skimage.measure import label
8+
9+
from .segmentation.postprocessing import compute_table_on_the_fly
10+
11+
SEGMENTATION_URL = "https://owncloud.gwdg.de/index.php/s/kwoGRYiJRRrswgw/download"
12+
13+
14+
def get_test_volume_and_segmentation(folder: str) -> Tuple[str, str, str]:
15+
"""Download a small volume with nuclei and corresponding segmentation.
16+
17+
Args:
18+
folder: The test data folder. The data will be downloaded to this folder.
19+
20+
Returns:
21+
The path to the image, stored as tif.
22+
The path to the segmentation, stored as tif.
23+
The path to the segmentation table, stored as tsv.
24+
"""
25+
os.makedirs(folder, exist_ok=True)
26+
27+
segmentation_path = os.path.join(folder, "segmentation.tif")
28+
resp = requests.get(SEGMENTATION_URL)
29+
resp.raise_for_status()
30+
31+
with open(segmentation_path, "wb") as f:
32+
f.write(resp.content)
33+
34+
nuclei = cells3d()[20:40, 1]
35+
segmentation = imageio.imread(segmentation_path)
36+
assert nuclei.shape == segmentation.shape
37+
38+
image_path = os.path.join(folder, "image.tif")
39+
imageio.imwrite(image_path, nuclei)
40+
41+
table_path = os.path.join(folder, "default.tsv")
42+
table = compute_table_on_the_fly(segmentation, resolution=0.38)
43+
table.to_csv(table_path, sep="\t", index=False)
44+
45+
return image_path, segmentation_path, table_path
46+
47+
48+
def create_image_data_and_segmentation(folder: str, size: int = 256) -> Tuple[str, str, str]:
49+
"""Create test data containing an image, a corresponding segmentation and segmentation table.
50+
51+
Args:
52+
folder: The test data folder. The data will be written to this folder.
53+
54+
Returns:
55+
The path to the image, stored as tif.
56+
The path to the segmentation, stored as tif.
57+
The path to the segmentation table, stored as tsv.
58+
"""
59+
os.makedirs(folder, exist_ok=True)
60+
data = binary_blobs(size, n_dim=3).astype("uint8") * 255
61+
seg = label(data)
62+
63+
image_path = os.path.join(folder, "image.tif")
64+
segmentation_path = os.path.join(folder, "segmentation.tif")
65+
imageio.imwrite(image_path, data)
66+
imageio.imwrite(segmentation_path, seg)
67+
68+
table_path = os.path.join(folder, "default.tsv")
69+
table = compute_table_on_the_fly(seg, resolution=0.38)
70+
table.to_csv(table_path, sep="\t", index=False)
71+
72+
return image_path, segmentation_path, table_path
573

674

775
# TODO add metadata
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
import os
3+
4+
import numpy as np
5+
import pandas as pd
6+
from flamingo_tools.s3_utils import create_s3_target, BUCKET_NAME
7+
8+
9+
def open_json(fs, path):
10+
s3_path = os.path.join(BUCKET_NAME, path)
11+
with fs.open(s3_path, "r") as f:
12+
content = json.load(f)
13+
return content
14+
15+
16+
def open_tsv(fs, path):
17+
s3_path = os.path.join(BUCKET_NAME, path)
18+
with fs.open(s3_path, "r") as f:
19+
table = pd.read_csv(f, sep="\t")
20+
return table
21+
22+
23+
def main():
24+
fs = create_s3_target()
25+
project_info = open_json(fs, "project.json")
26+
for dataset in project_info["datasets"]:
27+
if dataset == "fens":
28+
continue
29+
print(dataset)
30+
dataset_info = open_json(fs, os.path.join(dataset, "dataset.json"))
31+
sources = dataset_info["sources"]
32+
for source, source_info in sources.items():
33+
if not source.startswith("SGN"):
34+
continue
35+
assert "segmentation" in source_info
36+
source_info = source_info["segmentation"]
37+
table_path = source_info["tableData"]["tsv"]["relativePath"]
38+
table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
39+
component_labels = table.component_labels.values
40+
remaining_sgns = component_labels[component_labels != 0]
41+
print(source)
42+
print("Number of SGNs (all components) :", len(remaining_sgns))
43+
_, n_per_component = np.unique(remaining_sgns, return_counts=True)
44+
print("Number of SGNs (largest component):", max(n_per_component))
45+
46+
47+
if __name__ == "__main__":
48+
main()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
from glob import glob
3+
4+
import imageio.v3 as imageio
5+
import napari
6+
7+
8+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
9+
SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations"
10+
11+
12+
def main():
13+
files = sorted(glob(os.path.join(ROOT, "**/*.tif")))
14+
for ff in files:
15+
if "segmentations" in ff:
16+
return
17+
print("Visualizing", ff)
18+
rel_path = os.path.relpath(ff, ROOT)
19+
seg_path = os.path.join(SAVE_ROOT, rel_path)
20+
21+
image = imageio.imread(ff)
22+
if os.path.exists(seg_path):
23+
seg = imageio.imread(seg_path)
24+
else:
25+
seg = None
26+
27+
v = napari.Viewer()
28+
v.add_image(image)
29+
if seg is not None:
30+
v.add_labels(seg)
31+
napari.run()
32+
33+
34+
if __name__ == "__main__":
35+
main()

0 commit comments

Comments
 (0)