Skip to content

Commit 1836e50

Browse files
Merge pull request #56 from computational-cell-analytics/measurement-updates
Update measurement impl
2 parents 10e1a86 + 2ae7bd1 commit 1836e50

21 files changed

+933
-24
lines changed

flamingo_tools/measurements.py

Lines changed: 125 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
import os
33
from concurrent import futures
44
from functools import partial
5-
from typing import List, Optional
5+
from typing import List, Optional, Tuple
66

77
import numpy as np
88
import pandas as pd
99
import trimesh
10+
from elf.io import open_file
11+
from elf.wrapper.resized_volume import ResizedVolume
12+
from nifty.tools import blocking
1013
from skimage.measure import marching_cubes, regionprops_table
14+
from scipy.ndimage import binary_dilation
1115
from tqdm import tqdm
1216

1317
from .file_utils import read_image_data
@@ -29,9 +33,14 @@ def _measure_volume_and_surface(mask, resolution):
2933
return volume, surface
3034

3135

32-
def _get_bounding_box_and_center(table, seg_id, resolution, shape):
36+
def _get_bounding_box_and_center(table, seg_id, resolution, shape, dilation):
3337
row = table[table.label_id == seg_id]
3438

39+
if dilation is not None and dilation > 0:
40+
bb_extension = dilation + 1
41+
else:
42+
bb_extension = 1
43+
3544
bb_min = np.array([
3645
row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item()
3746
]).astype("float32") / resolution
@@ -43,7 +52,7 @@ def _get_bounding_box_and_center(table, seg_id, resolution, shape):
4352
bb_max = np.round(bb_max, 0).astype("int32")
4453

4554
bb = tuple(
46-
slice(max(bmin - 1, 0), min(bmax + 1, sh))
55+
slice(max(bmin - bb_extension, 0), min(bmax + bb_extension, sh))
4756
for bmin, bmax, sh in zip(bb_min, bb_max, shape)
4857
)
4958

@@ -115,13 +124,15 @@ def _normalize_background(measures, image, mask, center, radius, norm, median_on
115124

116125
def _default_object_features(
117126
seg_id, table, image, segmentation, resolution,
118-
foreground_mask=None, background_radius=None, norm=np.divide, median_only=False,
127+
background_mask=None, background_radius=None, norm=np.divide, median_only=False, dilation=None
119128
):
120-
bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)
129+
bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape, dilation)
121130

122131
local_image = image[bb]
123132
mask = segmentation[bb] == seg_id
124133
assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
134+
if dilation is not None and dilation > 0:
135+
mask = binary_dilation(mask, iterations=dilation)
125136
masked_intensity = local_image[mask]
126137

127138
# Do the base intensity measurements.
@@ -141,7 +152,7 @@ def _default_object_features(
141152
# The resolution is given in micrometer per pixel.
142153
# So we have to divide by the resolution to obtain the radius in pixel.
143154
radius_in_pixel = background_radius / resolution
144-
measures = _normalize_background(measures, image, foreground_mask, center, radius_in_pixel, norm, median_only)
155+
measures = _normalize_background(measures, image, background_mask, center, radius_in_pixel, norm, median_only)
145156

146157
# Do the volume and surface measurement.
147158
if not median_only:
@@ -151,13 +162,15 @@ def _default_object_features(
151162
return measures
152163

153164

154-
def _regionprops_features(seg_id, table, image, segmentation, resolution, foreground_mask=None):
155-
bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)
165+
def _regionprops_features(seg_id, table, image, segmentation, resolution, background_mask=None, dilation=None):
166+
bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape, dilation)
156167

157168
local_image = image[bb]
158169
local_segmentation = segmentation[bb]
159170
mask = local_segmentation == seg_id
160171
assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
172+
if dilation is not None and dilation > 0:
173+
mask = binary_dilation(mask, iterations=dilation)
161174
local_segmentation[~mask] = 0
162175

163176
features = regionprops_table(
@@ -196,16 +209,16 @@ def _regionprops_features(seg_id, table, image, segmentation, resolution, foregr
196209
"""
197210

198211

199-
# TODO integrate segmentation post-processing, see `_extend_sgns_simple` in `gfp_annotation.py`
200212
def compute_object_measures_impl(
201213
image: np.typing.ArrayLike,
202214
segmentation: np.typing.ArrayLike,
203215
n_threads: Optional[int] = None,
204216
resolution: float = 0.38,
205217
table: Optional[pd.DataFrame] = None,
206218
feature_set: str = "default",
207-
foreground_mask: Optional[np.typing.ArrayLike] = None,
219+
background_mask: Optional[np.typing.ArrayLike] = None,
208220
median_only: bool = False,
221+
dilation: Optional[int] = None,
209222
) -> pd.DataFrame:
210223
"""Compute simple intensity and morphology measures for each segmented cell in a segmentation.
211224
@@ -218,8 +231,10 @@ def compute_object_measures_impl(
218231
resolution: The resolution / voxel size of the data.
219232
table: The segmentation table. Will be computed on the fly if it is not given.
220233
feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
221-
foreground_mask: An optional mask indicating the area to use for computing background correction values.
234+
background_mask: An optional mask indicating the area to use for computing background correction values.
222235
median_only: Whether to only compute the median intensity.
236+
dilation: Value for dilating the segmentation before computing measurements.
237+
By default no dilation is applied.
223238
224239
Returns:
225240
The table with per object measurements.
@@ -235,8 +250,9 @@ def compute_object_measures_impl(
235250
image=image,
236251
segmentation=segmentation,
237252
resolution=resolution,
238-
foreground_mask=foreground_mask,
253+
background_mask=background_mask,
239254
median_only=median_only,
255+
dilation=dilation,
240256
)
241257

242258
seg_ids = table.label_id.values
@@ -246,6 +262,7 @@ def compute_object_measures_impl(
246262

247263
# For debugging.
248264
# measure_function(seg_ids[0])
265+
# breakpoint()
249266

250267
with futures.ThreadPoolExecutor(n_threads) as pool:
251268
measures = list(tqdm(
@@ -272,6 +289,9 @@ def compute_object_measures(
272289
feature_set: str = "default",
273290
s3_flag: bool = False,
274291
component_list: List[int] = [],
292+
dilation: Optional[int] = None,
293+
median_only: bool = False,
294+
background_mask: Optional[np.typing.ArrayLike] = None,
275295
) -> None:
276296
"""Compute simple intensity and morphology measures for each segmented cell in a segmentation.
277297
@@ -291,6 +311,12 @@ def compute_object_measures(
291311
resolution: The resolution / voxel size of the data.
292312
force: Whether to overwrite an existing output table.
293313
feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
314+
s3_flag:
315+
component_list:
316+
median_only: Whether to only compute the median intensity.
317+
dilation: Value for dilating the segmentation before computing measurements.
318+
By default no dilation is applied.
319+
background_mask: An optional mask indicating the area to use for computing background correction values.
294320
"""
295321
if os.path.exists(output_table_path) and not force:
296322
return
@@ -315,5 +341,92 @@ def compute_object_measures(
315341

316342
measures = compute_object_measures_impl(
317343
image, segmentation, n_threads, resolution, table=table, feature_set=feature_set,
344+
median_only=median_only, dilation=dilation, background_mask=background_mask,
318345
)
319346
measures.to_csv(output_table_path, sep="\t", index=False)
347+
348+
349+
def compute_sgn_background_mask(
350+
image_path: str,
351+
segmentation_path: str,
352+
image_key: Optional[str] = None,
353+
segmentation_key: Optional[str] = None,
354+
threshold_percentile: float = 35.0,
355+
scale_factor: Tuple[int, int, int] = (16, 16, 16),
356+
n_threads: Optional[int] = None,
357+
cache_path: Optional[str] = None,
358+
) -> np.typing.ArrayLike:
359+
"""Compute the background mask for intensity measurements in the SGN segmentation.
360+
361+
This function computes a mask for determining the background signal in the rosenthal canal.
362+
It is computed by downsampling the image (PV) and segmentation (SGNs) internally,
363+
by thresholding the downsampled image, and by then intersecting this mask with the segmentation.
364+
This results in a mask that is positive for the background signal within the rosenthal canal.
365+
366+
Args:
367+
image_path: The path to the image data with the PV channel.
368+
segmentation_path: The path to the SGN segmentation.
369+
image_key: Internal path for the image data, for zarr or similar file formats.
370+
segmentation_key: Internal path for the segmentation data, for zarr or similar file formats.
371+
threshold_percentile: The percentile threshold for separating foreground and background in the PV signal.
372+
scale_factor: The scale factor for internally downsampling the mask.
373+
n_threads: The number of threads for parallelizing the computation.
374+
cache_path: Optional path to save the downscaled background mask to zarr.
375+
376+
Returns:
377+
The mask for determining the background values.
378+
"""
379+
image = read_image_data(image_path, image_key)
380+
segmentation = read_image_data(segmentation_path, segmentation_key)
381+
assert image.shape == segmentation.shape
382+
383+
if cache_path is not None and os.path.exists(cache_path):
384+
with open_file(cache_path, "r") as f:
385+
if "mask" in f:
386+
low_res_mask = f["mask"][:]
387+
mask = ResizedVolume(low_res_mask, shape=image.shape, order=0)
388+
return mask
389+
390+
original_shape = image.shape
391+
downsampled_shape = tuple(int(np.round(sh / sf)) for sh, sf in zip(original_shape, scale_factor))
392+
393+
low_res_mask = np.zeros(downsampled_shape, dtype="bool")
394+
395+
# This corresponds to a block shape of 128 x 512 x 512 in the original resolution,
396+
# which roughly corresponds to the size of the blocks we use for the GFP annotation.
397+
chunk_shape = (8, 32, 32)
398+
399+
blocks = blocking((0, 0, 0), downsampled_shape, chunk_shape)
400+
n_blocks = blocks.numberOfBlocks
401+
402+
img_resized = ResizedVolume(image, downsampled_shape)
403+
seg_resized = ResizedVolume(segmentation, downsampled_shape, order=0)
404+
405+
def _compute_block(block_id):
406+
block = blocks.getBlock(block_id)
407+
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
408+
409+
img = img_resized[bb]
410+
threshold = np.percentile(img, threshold_percentile)
411+
412+
this_mask = img > threshold
413+
this_seg = seg_resized[bb] != 0
414+
this_seg = binary_dilation(this_seg)
415+
this_mask[this_seg] = 0
416+
417+
low_res_mask[bb] = this_mask
418+
419+
n_threads = mp.cpu_count() if n_threads is None else n_threads
420+
randomized_blocks = np.arange(0, n_blocks)
421+
np.random.shuffle(randomized_blocks)
422+
with futures.ThreadPoolExecutor(n_threads) as tp:
423+
list(tqdm(
424+
tp.map(_compute_block, randomized_blocks), total=n_blocks, desc="Compute background mask"
425+
))
426+
427+
if cache_path is not None:
428+
with open_file(cache_path, "a") as f:
429+
f.create_dataset("mask", data=low_res_mask, chunks=(64, 64, 64))
430+
431+
mask = ResizedVolume(low_res_mask, shape=original_shape, order=0)
432+
return mask

0 commit comments

Comments
 (0)