- import numpy as np
- import vigra
import multiprocessing as mp
from concurrent import futures
+ from typing import Callable, Tuple, Optional

- from skimage import measure
+ import elf.parallel as parallel
+ import numpy as np
+ import nifty.tools as nt
+ import pandas as pd
+ import vigra
+
+ from elf.io import open_file
from scipy.spatial import distance
from scipy.sparse import csr_matrix
- from tqdm import tqdm
+ from scipy.spatial import cKDTree, ConvexHull
+ from skimage import measure
from sklearn.neighbors import NearestNeighbors
+ from tqdm import tqdm

- import elf.parallel as parallel
- from elf.io import open_file
- import nifty.tools as nt

+ #
+ # Spatial statistics:
+ # Three different spatial statistics implementations that
+ # can be used as the basis of a filtering criterion.
+ #

- def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True):
-     """Calculate average distance of n nearest neighbors.

-     :param DataFrame tsv_table:
-     :param int n_neighbors: Number of nearest neighbors
-     :param bool expand_table: Flag for expanding DataFrame
-     :returns: List of average distances
-     :rtype: list
-     """
-     centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
+ def nearest_neighbor_distance(table: pd.DataFrame, n_neighbors: int = 10) -> np.ndarray:
+     """Compute the average distance to the n nearest neighbors.
+
+     Args:
+         table: The table with the centroid coordinates.
+         n_neighbors: The number of neighbors to take into account for the distance computation.

-     coordinates = np.array(centroids)
+     Returns:
+         The average distances to the n nearest neighbors.
+     """
+     centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+     centroids = np.array(centroids)

-     # nearest neighbor is always itself, so n_neighbors+=1
-     nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(coordinates)
-     distances, indices = nbrs.kneighbors(coordinates)
+     # Nearest neighbor is always itself, so n_neighbors+=1.
+     nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(centroids)
+     distances, indices = nbrs.kneighbors(centroids)

    # Average distance to nearest neighbors
-     distance_avg = [sum(d) / len(d) for d in distances[:, 1:]]
+     distance_avg = np.array([sum(d) / len(d) for d in distances[:, 1:]])
+     return distance_avg

-     if expand_table:
-         tsv_table['distance_nn' + str(n_neighbors)] = distance_avg

-     return distance_avg
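
# An illustrative usage sketch for nearest_neighbor_distance: it only assumes a table with
# MoBIE-style "anchor_x", "anchor_y", "anchor_z" columns, so synthetic centroids (made up here)
# are enough to try it out.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
toy_table = pd.DataFrame(
    rng.uniform(0, 100, size=(200, 3)), columns=["anchor_x", "anchor_y", "anchor_z"]
)
avg_dist = nearest_neighbor_distance(toy_table, n_neighbors=5)  # one average distance per object
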
+ def local_ripleys_k(table: pd.DataFrame, radius: float = 15, volume: Optional[float] = None) -> np.ndarray:
+     """Compute the local Ripley's K function for each point in a 2D or 3D point set.

+     Args:
+         table: The table with the centroid coordinates.
+         radius: The radius within which to count neighboring points.
+         volume: The area (2D) or volume (3D) of the study region. If None, it is estimated from the convex hull.
+
+     Returns:
+         An array containing the local K values for each point.
+     """
+     points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+     points = np.array(points)
+     n_points, dim = points.shape
+
+     if dim not in (2, 3):
+         raise ValueError("Points array must be of shape (n_points, 2) or (n_points, 3).")
+
+     # Estimate area/volume if not provided.
+     if volume is None:
+         hull = ConvexHull(points)
+         volume = hull.volume  # For 2D, 'volume' is area; for 3D, it's volume.
+
+     # Compute point density.
+     density = n_points / volume
+
+     # Build a KD-tree for efficient neighbor search.
+     tree = cKDTree(points)
+
+     # Count neighbors within the specified radius for each point.
+     counts = tree.query_ball_point(points, r=radius)
+     local_counts = np.array([len(c) - 1 for c in counts])  # Exclude the point itself.
+
+     # Normalize by density to get local K values.
+     local_k = local_counts / density
+     return local_k

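# An illustrative sketch for local_ripleys_k on synthetic data (all values made up): points inside
# a dense cluster receive larger local K values than scattered points at the same radius.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
cluster = rng.normal(loc=50, scale=2, size=(100, 3))
scattered = rng.uniform(0, 100, size=(20, 3))
coords = np.concatenate([cluster, scattered])
toy_table = pd.DataFrame(coords, columns=["anchor_x", "anchor_y", "anchor_z"])
local_k = local_ripleys_k(toy_table, radius=10)  # larger values inside the cluster
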
- def filter_isolated_objects(
-     segmentation, output_path, tsv_table=None,
-     distance_threshold=15, neighbor_threshold=5, min_size=1000,
-     output_key="segmentation_postprocessed",
- ):
-     """Postprocessing step to filter isolated objects from a segmentation.

-     Instance segmentations are filtered if they have fewer neighbors
-     than a given threshold in a given distance around them.
-     Additionally, size filtering is possible if a TSV file is supplied.
-
-     :param dataset segmentation: Dataset containing the segmentation
-     :param str out_path: Output path for postprocessed segmentation
-     :param str tsv_file: Optional TSV file containing segmentation parameters in MoBIE format
-     :param int distance_threshold: Distance in micrometer to check for neighbors
-     :param int neighbor_threshold: Minimal number of neighbors for filtering
-     :param int min_size: Minimal number of pixels for filtering small instances
-     :param str output_key: Output key for postprocessed segmentation
+ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray:
+     """Compute the number of neighbors within a given radius.
+
+     Args:
+         table: The table with the centroid coordinates.
+         radius: The radius within which to count neighboring points.
+
+     Returns:
+         An array containing the number of neighbors within the given radius.
    """
-     if tsv_table is not None:
-         n_pixels = tsv_table["n_pixels"].to_list()
-         label_ids = tsv_table["label_id"].to_list()
-         centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
-         n_ids = len(label_ids)
-
-         # filter out cells smaller than min_size
-         if min_size is not None:
-             min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size]
-             centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids]
-             label_ids = [int(lid) for lid in label_ids if lid not in min_size_label_ids]
-
-         coordinates = np.array(centroids)
-         label_ids = np.array(label_ids)
-
-     else:
-         segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
-         props = measure.regionprops(segmentation)
-         coordinates = np.array([prop.centroid for prop in props])
-         label_ids = np.unique(segmentation)[1:]
-
-     # Calculate pairwise distances and convert to a square matrix
-     dist_matrix = distance.pdist(coordinates)
+     points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+     points = np.array(points)
+
+     dist_matrix = distance.pdist(points)
    dist_matrix = distance.squareform(dist_matrix)

-     # Create sparse matrix of connections within the threshold distance
-     sparse_matrix = csr_matrix(dist_matrix < distance_threshold, dtype=int)
+     # Create sparse matrix of connections within the threshold distance.
+     sparse_matrix = csr_matrix(dist_matrix < radius, dtype=int)

-     # Sum each row to count neighbors
+     # Sum each row to count neighbors.
    neighbor_counts = sparse_matrix.sum(axis=1)
+     return np.array(neighbor_counts)
+
+
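# An illustrative sketch for neighbors_in_radius (synthetic data, values made up): these per-object
# counts are what filter_segmentation below compares against its threshold when this statistic is used.
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
toy_table = pd.DataFrame(
    rng.uniform(0, 50, size=(100, 3)), columns=["anchor_x", "anchor_y", "anchor_z"]
)
counts = neighbors_in_radius(toy_table, radius=15)
isolated = counts.squeeze() < 5  # e.g. objects with fewer than 5 neighbors within the radius
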
+ #
+ # Filtering function:
+ # Filter the segmentation based on one of the spatial statistics from above.
+ #
+
+
+ def _compute_table(segmentation):
+     segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
+     props = measure.regionprops(segmentation)
+     # Note: regionprops already excludes the background label, so the properties are not sliced here.
+     coordinates = np.array([prop.centroid for prop in props])
+     label_ids = np.unique(segmentation)[1:]
+     sizes = np.array([prop.area for prop in props])
+     table = pd.DataFrame({
+         "label_id": label_ids,
+         "n_pixels": sizes,
+         "anchor_x": coordinates[:, 2],
+         "anchor_y": coordinates[:, 1],
+         "anchor_z": coordinates[:, 0],
+     })
+     return table
+
+
+ def filter_segmentation(
+     segmentation: np.typing.ArrayLike,
+     output_path: str,
+     spatial_statistics: Callable,
+     threshold: float,
+     min_size: int = 1000,
+     table: Optional[pd.DataFrame] = None,
+     output_key: str = "segmentation_postprocessed",
+ ) -> Tuple[int, int]:
+     """Postprocessing step to filter isolated objects from a segmentation.

-     filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
-     filter_ids = label_ids[filter_mask]
+     Instance segmentations are filtered based on spatial statistics and a threshold.
+     In addition, objects smaller than a given size are filtered out.
+
+     Args:
+         segmentation: Dataset containing the segmentation.
+         output_path: Output path for the postprocessed segmentation.
+         spatial_statistics: Callable that computes the per-object spatial statistic from the table,
+             e.g. one of the functions above.
+         threshold: Threshold applied to the spatial statistic; objects with values above it are kept.
+         min_size: Minimal number of pixels; smaller instances are filtered out.
+         table: Optional table with the per-object measurements in MoBIE format.
+             If None, it is computed from the segmentation.
+         output_key: Output key for the postprocessed segmentation.
+
+     Returns:
+         The number of objects before filtering.
+         The number of objects after filtering.
159+ """
160+ # Compute the table on the fly.
161+ # NOTE: this currently doesn't work for large segmentations.
162+ if table is None :
163+ table = _compute_table (segmentation )
164+ n_ids = len (table )
165+
166+ # First apply the size filter.
167+ table = table [table .n_pixels > min_size ]
168+ stat_values = spatial_statistics (table )
169+
170+ keep_mask = np .array (stat_values > threshold ).squeeze ()
171+ keep_ids = table .label_id .values [keep_mask ]
95172
96173 shape = segmentation .shape
97174 block_shape = (128 , 128 , 128 )
@@ -100,7 +177,6 @@ def filter_isolated_objects(
    blocking = nt.blocking([0] * len(shape), shape, block_shape)

    output = open_file(output_path, mode="a")
-
    output_dataset = output.create_dataset(
        output_key, shape=shape, dtype=segmentation.dtype,
        chunks=chunks, compression="gzip"
@@ -112,17 +188,16 @@ def filter_chunk(block_id):
        block = blocking.getBlock(block_id)
        volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
        data = segmentation[volume_index]
-         data[np.isin(data, filter_ids)] = 0
+         data[~np.isin(data, keep_ids)] = 0  # Zero out everything that is not in the ids to keep.
        output_dataset[volume_index] = data

    # Limit the number of cores for parallelization.
    n_threads = min(16, mp.cpu_count())
-
    with futures.ThreadPoolExecutor(n_threads) as filter_pool:
        list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))

    seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(
-         output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128)
+         output_dataset, start_label=1, keep_zeros=True, block_shape=block_shape
    )

-     return seg_filtered, n_ids, n_ids_filtered
+     return n_ids, n_ids_filtered
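
# An illustrative end-to-end sketch for filter_segmentation; the toy volume, the output path and all
# parameter values below are assumptions for demonstration, not part of the module. Any of the three
# statistics above can be bound with functools.partial so that it only takes the table argument.
from functools import partial

import numpy as np

rng = np.random.default_rng(3)
toy_segmentation = rng.integers(0, 50, size=(256, 256, 256), dtype="uint32")  # stand-in label volume

n_ids, n_ids_filtered = filter_segmentation(
    toy_segmentation,
    output_path="postprocessed.n5",  # hypothetical output path
    spatial_statistics=partial(neighbors_in_radius, radius=15),
    threshold=5,  # keep objects with more than 5 neighbors within the radius
    min_size=1000,
    table=None,  # the table is then computed on the fly via _compute_table
)
print(f"Kept {n_ids_filtered} of {n_ids} objects.")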