22import os
33from concurrent import futures
44from functools import partial
5- from typing import List , Optional
5+ from typing import List , Optional , Tuple
66
77import numpy as np
88import pandas as pd
99import trimesh
10+ from elf .io import open_file
11+ from elf .wrapper .resized_volume import ResizedVolume
12+ from nifty .tools import blocking
1013from skimage .measure import marching_cubes , regionprops_table
14+ from scipy .ndimage import binary_dilation
1115from tqdm import tqdm
1216
1317from .file_utils import read_image_data
@@ -29,9 +33,14 @@ def _measure_volume_and_surface(mask, resolution):
2933 return volume , surface
3034
3135
32- def _get_bounding_box_and_center (table , seg_id , resolution , shape ):
36+ def _get_bounding_box_and_center (table , seg_id , resolution , shape , dilation ):
3337 row = table [table .label_id == seg_id ]
3438
39+ if dilation is not None and dilation > 0 :
40+ bb_extension = dilation + 1
41+ else :
42+ bb_extension = 1
43+
3544 bb_min = np .array ([
3645 row .bb_min_z .item (), row .bb_min_y .item (), row .bb_min_x .item ()
3746 ]).astype ("float32" ) / resolution
@@ -43,7 +52,7 @@ def _get_bounding_box_and_center(table, seg_id, resolution, shape):
4352 bb_max = np .round (bb_max , 0 ).astype ("int32" )
4453
4554 bb = tuple (
46- slice (max (bmin - 1 , 0 ), min (bmax + 1 , sh ))
55+ slice (max (bmin - bb_extension , 0 ), min (bmax + bb_extension , sh ))
4756 for bmin , bmax , sh in zip (bb_min , bb_max , shape )
4857 )
4958
@@ -115,13 +124,15 @@ def _normalize_background(measures, image, mask, center, radius, norm, median_on
115124
116125def _default_object_features (
117126 seg_id , table , image , segmentation , resolution ,
118- foreground_mask = None , background_radius = None , norm = np .divide , median_only = False ,
127+ background_mask = None , background_radius = None , norm = np .divide , median_only = False , dilation = None
119128):
120- bb , center = _get_bounding_box_and_center (table , seg_id , resolution , image .shape )
129+ bb , center = _get_bounding_box_and_center (table , seg_id , resolution , image .shape , dilation )
121130
122131 local_image = image [bb ]
123132 mask = segmentation [bb ] == seg_id
124133 assert mask .sum () > 0 , f"Segmentation ID { seg_id } is empty."
134+ if dilation is not None and dilation > 0 :
135+ mask = binary_dilation (mask , iterations = dilation )
125136 masked_intensity = local_image [mask ]
126137
127138 # Do the base intensity measurements.
@@ -141,7 +152,7 @@ def _default_object_features(
141152 # The resolution is given in micrometer per pixel.
142153 # So we have to divide by the resolution to obtain the radius in pixel.
143154 radius_in_pixel = background_radius / resolution
144- measures = _normalize_background (measures , image , foreground_mask , center , radius_in_pixel , norm , median_only )
155+ measures = _normalize_background (measures , image , background_mask , center , radius_in_pixel , norm , median_only )
145156
146157 # Do the volume and surface measurement.
147158 if not median_only :
@@ -151,13 +162,15 @@ def _default_object_features(
151162 return measures
152163
153164
154- def _regionprops_features (seg_id , table , image , segmentation , resolution , foreground_mask = None ):
155- bb , _ = _get_bounding_box_and_center (table , seg_id , resolution , image .shape )
165+ def _regionprops_features (seg_id , table , image , segmentation , resolution , background_mask = None , dilation = None ):
166+ bb , _ = _get_bounding_box_and_center (table , seg_id , resolution , image .shape , dilation )
156167
157168 local_image = image [bb ]
158169 local_segmentation = segmentation [bb ]
159170 mask = local_segmentation == seg_id
160171 assert mask .sum () > 0 , f"Segmentation ID { seg_id } is empty."
172+ if dilation is not None and dilation > 0 :
173+ mask = binary_dilation (mask , iterations = dilation )
161174 local_segmentation [~ mask ] = 0
162175
163176 features = regionprops_table (
@@ -196,16 +209,16 @@ def _regionprops_features(seg_id, table, image, segmentation, resolution, foregr
196209"""
197210
198211
199- # TODO integrate segmentation post-processing, see `_extend_sgns_simple` in `gfp_annotation.py`
200212def compute_object_measures_impl (
201213 image : np .typing .ArrayLike ,
202214 segmentation : np .typing .ArrayLike ,
203215 n_threads : Optional [int ] = None ,
204216 resolution : float = 0.38 ,
205217 table : Optional [pd .DataFrame ] = None ,
206218 feature_set : str = "default" ,
207- foreground_mask : Optional [np .typing .ArrayLike ] = None ,
219+ background_mask : Optional [np .typing .ArrayLike ] = None ,
208220 median_only : bool = False ,
221+ dilation : Optional [int ] = None ,
209222) -> pd .DataFrame :
210223 """Compute simple intensity and morphology measures for each segmented cell in a segmentation.
211224
@@ -218,8 +231,10 @@ def compute_object_measures_impl(
218231 resolution: The resolution / voxel size of the data.
219232 table: The segmentation table. Will be computed on the fly if it is not given.
220233 feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
221- foreground_mask : An optional mask indicating the area to use for computing background correction values.
234+ background_mask : An optional mask indicating the area to use for computing background correction values.
222235 median_only: Whether to only compute the median intensity.
236+ dilation: Value for dilating the segmentation before computing measurements.
237+ By default no dilation is applied.
223238
224239 Returns:
225240 The table with per object measurements.
@@ -235,8 +250,9 @@ def compute_object_measures_impl(
235250 image = image ,
236251 segmentation = segmentation ,
237252 resolution = resolution ,
238- foreground_mask = foreground_mask ,
253+ background_mask = background_mask ,
239254 median_only = median_only ,
255+ dilation = dilation ,
240256 )
241257
242258 seg_ids = table .label_id .values
@@ -246,6 +262,7 @@ def compute_object_measures_impl(
246262
247263 # For debugging.
248264 # measure_function(seg_ids[0])
265+ # breakpoint()
249266
250267 with futures .ThreadPoolExecutor (n_threads ) as pool :
251268 measures = list (tqdm (
@@ -272,6 +289,9 @@ def compute_object_measures(
272289 feature_set : str = "default" ,
273290 s3_flag : bool = False ,
274291 component_list : List [int ] = [],
292+ dilation : Optional [int ] = None ,
293+ median_only : bool = False ,
294+ background_mask : Optional [np .typing .ArrayLike ] = None ,
275295) -> None :
276296 """Compute simple intensity and morphology measures for each segmented cell in a segmentation.
277297
@@ -291,6 +311,12 @@ def compute_object_measures(
291311 resolution: The resolution / voxel size of the data.
292312 force: Whether to overwrite an existing output table.
293313 feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
314+ s3_flag:
315+ component_list:
316+ median_only: Whether to only compute the median intensity.
317+ dilation: Value for dilating the segmentation before computing measurements.
318+ By default no dilation is applied.
319+ background_mask: An optional mask indicating the area to use for computing background correction values.
294320 """
295321 if os .path .exists (output_table_path ) and not force :
296322 return
@@ -315,5 +341,92 @@ def compute_object_measures(
315341
316342 measures = compute_object_measures_impl (
317343 image , segmentation , n_threads , resolution , table = table , feature_set = feature_set ,
344+ median_only = median_only , dilation = dilation , background_mask = background_mask ,
318345 )
319346 measures .to_csv (output_table_path , sep = "\t " , index = False )
347+
348+
349+ def compute_sgn_background_mask (
350+ image_path : str ,
351+ segmentation_path : str ,
352+ image_key : Optional [str ] = None ,
353+ segmentation_key : Optional [str ] = None ,
354+ threshold_percentile : float = 35.0 ,
355+ scale_factor : Tuple [int , int , int ] = (16 , 16 , 16 ),
356+ n_threads : Optional [int ] = None ,
357+ cache_path : Optional [str ] = None ,
358+ ) -> np .typing .ArrayLike :
359+ """Compute the background mask for intensity measurements in the SGN segmentation.
360+
361+ This function computes a mask for determining the background signal in the rosenthal canal.
362+ It is computed by downsampling the image (PV) and segmentation (SGNs) internally,
363+ by thresholding the downsampled image, and by then intersecting this mask with the segmentation.
364+ This results in a mask that is positive for the background signal within the rosenthal canal.
365+
366+ Args:
367+ image_path: The path to the image data with the PV channel.
368+ segmentation_path: The path to the SGN segmentation.
369+ image_key: Internal path for the image data, for zarr or similar file formats.
370+ segmentation_key: Internal path for the segmentation data, for zarr or similar file formats.
371+ threshold_percentile: The percentile threshold for separating foreground and background in the PV signal.
372+ scale_factor: The scale factor for internally downsampling the mask.
373+ n_threads: The number of threads for parallelizing the computation.
374+ cache_path: Optional path to save the downscaled background mask to zarr.
375+
376+ Returns:
377+ The mask for determining the background values.
378+ """
379+ image = read_image_data (image_path , image_key )
380+ segmentation = read_image_data (segmentation_path , segmentation_key )
381+ assert image .shape == segmentation .shape
382+
383+ if cache_path is not None and os .path .exists (cache_path ):
384+ with open_file (cache_path , "r" ) as f :
385+ if "mask" in f :
386+ low_res_mask = f ["mask" ][:]
387+ mask = ResizedVolume (low_res_mask , shape = image .shape , order = 0 )
388+ return mask
389+
390+ original_shape = image .shape
391+ downsampled_shape = tuple (int (np .round (sh / sf )) for sh , sf in zip (original_shape , scale_factor ))
392+
393+ low_res_mask = np .zeros (downsampled_shape , dtype = "bool" )
394+
395+ # This corresponds to a block shape of 128 x 512 x 512 in the original resolution,
396+ # which roughly corresponds to the size of the blocks we use for the GFP annotation.
397+ chunk_shape = (8 , 32 , 32 )
398+
399+ blocks = blocking ((0 , 0 , 0 ), downsampled_shape , chunk_shape )
400+ n_blocks = blocks .numberOfBlocks
401+
402+ img_resized = ResizedVolume (image , downsampled_shape )
403+ seg_resized = ResizedVolume (segmentation , downsampled_shape , order = 0 )
404+
405+ def _compute_block (block_id ):
406+ block = blocks .getBlock (block_id )
407+ bb = tuple (slice (beg , end ) for beg , end in zip (block .begin , block .end ))
408+
409+ img = img_resized [bb ]
410+ threshold = np .percentile (img , threshold_percentile )
411+
412+ this_mask = img > threshold
413+ this_seg = seg_resized [bb ] != 0
414+ this_seg = binary_dilation (this_seg )
415+ this_mask [this_seg ] = 0
416+
417+ low_res_mask [bb ] = this_mask
418+
419+ n_threads = mp .cpu_count () if n_threads is None else n_threads
420+ randomized_blocks = np .arange (0 , n_blocks )
421+ np .random .shuffle (randomized_blocks )
422+ with futures .ThreadPoolExecutor (n_threads ) as tp :
423+ list (tqdm (
424+ tp .map (_compute_block , randomized_blocks ), total = n_blocks , desc = "Compute background mask"
425+ ))
426+
427+ if cache_path is not None :
428+ with open_file (cache_path , "a" ) as f :
429+ f .create_dataset ("mask" , data = low_res_mask , chunks = (64 , 64 , 64 ))
430+
431+ mask = ResizedVolume (low_res_mask , shape = original_shape , order = 0 )
432+ return mask
0 commit comments