@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import os
 from concurrent import futures
 from typing import Callable, Tuple, Optional

@@ -8,8 +9,11 @@
 import pandas as pd

 from elf.io import open_file
-from scipy.spatial import distance
+from scipy.ndimage import binary_fill_holes
+from scipy.ndimage import distance_transform_edt
+from scipy.ndimage import label
 from scipy.sparse import csr_matrix
+from scipy.spatial import distance
 from scipy.spatial import cKDTree, ConvexHull
 from skimage import measure
 from sklearn.neighbors import NearestNeighbors
@@ -205,3 +209,237 @@ def filter_chunk(block_id):
     )

     return n_ids, n_ids_filtered
+
+
+# Postprocess the segmentation by erosion, using the spatial statistics computed above.
+# Currently implemented by downscaling the centroids and searching for connected components.
+# TODO: Change the implementation to graph-based connected components.
+
+
+def erode_subset(
+    table: pd.DataFrame,
+    iterations: Optional[int] = 1,
+    min_cells: Optional[int] = None,
+    threshold: Optional[int] = 35,
+    keyword: Optional[str] = "distance_nn100",
+) -> pd.DataFrame:
+    """Erode the coordinates of a dataframe according to a keyword column and a threshold.
+    Pass a copy of the dataframe if the original should not be modified.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        iterations: Number of erosion steps.
+        min_cells: Minimal number of rows. The erosion is stopped before falling below this number.
+        threshold: Upper threshold for removing elements according to the given keyword.
+        keyword: Keyword of the dataframe column used for the erosion.
+
+    Returns:
+        The dataframe containing the elements left after the erosion.
+    """
+    print("initial length", len(table))
+    n_neighbors = 100
+    for i in range(iterations):
+        table = table[table[keyword] < threshold]
+
+        # TODO: support other spatial statistics
+        distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)
+
+        if min_cells is not None and len(distance_avg) < min_cells:
+            print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
+            break
+
+        table.loc[:, "distance_nn" + str(n_neighbors)] = list(distance_avg)
+
+        print(f"{i}-th iteration, length of subset {len(table)}")
+
+    return table
+
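Illustrative usage only (not part of the commit): a minimal sketch of erode_subset on a synthetic table, assuming erode_subset and nearest_neighbor_distance (defined earlier in this file) are in scope and accept a dataframe with the anchor_x/anchor_y/anchor_z columns used above.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "anchor_x": rng.uniform(0, 100, size=2000),
    "anchor_y": rng.uniform(0, 100, size=2000),
    "anchor_z": rng.uniform(0, 50, size=2000),
})
# Hypothetical pre-computation of the spatial statistic used as erosion keyword;
# the call signature is taken from erode_subset above.
toy["distance_nn100"] = list(nearest_neighbor_distance(toy, n_neighbors=100))
eroded = erode_subset(toy.copy(), iterations=3, threshold=35, min_cells=500)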
+
+def downscaled_centroids(
+    table: pd.DataFrame,
+    scale_factor: int,
+    ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    capped: Optional[bool] = True,
+) -> np.typing.NDArray:
+    """Downscale the centroids of a dataframe to a coarse grid.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        scale_factor: Factor for downscaling the coordinates.
+        ref_dimensions: Reference dimensions for the downscaling. Taken from the centroids if not supplied.
+        capped: Whether to cap the output array at 1 to create a binary mask.
+
+    Returns:
+        The downscaled array.
+    """
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]
+
+    if ref_dimensions is None:
+        bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+        bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in bounding_dimensions)
+        new_array = np.zeros(bounding_dimensions_scaled)
+
+    else:
+        bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in ref_dimensions)
+        new_array = np.zeros(bounding_dimensions_scaled)
+
+    for c in centroids_scaled:
+        new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+
+    array_downscaled = np.round(new_array).astype(int)
+
+    if capped:
+        array_downscaled[array_downscaled >= 1] = 1
+
+    return array_downscaled
+
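Illustrative usage only: downscaled_centroids bins the anchor coordinates into a coarse grid; with capped=True (the default) the result is a binary occupancy mask, with capped=False it holds per-block centroid counts. The toy table is hypothetical.

import pandas as pd

toy = pd.DataFrame({
    "anchor_x": [10.0, 12.0, 250.0],
    "anchor_y": [5.0, 7.0, 130.0],
    "anchor_z": [2.0, 3.0, 40.0],
})
mask = downscaled_centroids(toy, scale_factor=20)                   # binary occupancy mask
counts = downscaled_centroids(toy, scale_factor=20, capped=False)   # per-block centroid counts
print(mask.shape, int(mask.sum()), int(counts.max()))               # (13, 7, 3) 2 2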
+
+def coordinates_in_downscaled_blocks(
+    table: pd.DataFrame,
+    down_array: np.typing.NDArray,
+    scale_factor: float,
+    distance_component: Optional[int] = 0,
+) -> list:
+    """Check whether coordinates lie within the downscaled array.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        down_array: Downscaled array.
+        scale_factor: Factor which was used for the downscaling.
+        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are still included.
+
+    Returns:
+        A binary list indicating for each row whether its coordinates lie within the array.
+    """
+    # Fill holes in the downsampled array.
+    down_array[down_array > 0] = 1
+    down_array = binary_fill_holes(down_array).astype(np.uint8)
+
+    # Check whether the input coordinates lie within the downsampled blocks.
+    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
+    centroids_scaled = [np.floor(np.array([c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor])) for c in centroids]
+
+    distance_map = distance_transform_edt(down_array == 0)
+
+    centroids_binary = []
+    for c in centroids_scaled:
+        coord = (int(c[0]), int(c[1]), int(c[2]))
+        if down_array[coord] != 0:
+            centroids_binary.append(1)
+        elif distance_map[coord] <= distance_component:
+            centroids_binary.append(1)
+        else:
+            centroids_binary.append(0)
+
+    return centroids_binary
+
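Illustrative usage only, continuing the toy table from the previous sketch: check which centroids fall inside the binary mask, or within one downscaled voxel of it. Note that the function modifies the array it receives in place, hence the copy.

mask = downscaled_centroids(toy, scale_factor=20)
inside = coordinates_in_downscaled_blocks(toy, mask.copy(), scale_factor=20, distance_component=1)
print(inside)  # one 0/1 entry per row of the table, here [1, 1, 1]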
+
+def erode_sgn_seg(
+    table: pd.DataFrame,
+    keyword: Optional[str] = "distance_nn100",
+    filter_small_components: Optional[int] = None,
+    scale_factor: Optional[float] = 20,
+    threshold_erode: Optional[float] = None,
+) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
+    """Erode the SGN segmentation.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        keyword: Keyword of the dataframe column used for the erosion.
+        filter_small_components: Remove labeled components smaller than this number of blocks.
+        scale_factor: Scale factor for the downsampling.
+        threshold_erode: Threshold on the column value for the erosion step; derived from the spatial statistics if not given.
+
+    Returns:
+        The labeled components of the downscaled, eroded coordinates.
+        The largest connected component of the labeled components.
+    """
+
+    ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+    print("initial length", len(table))
+    distance_nn = list(table[keyword])
+    distance_nn.sort()
+
+    if len(table) < 20000:
+        iterations = 1
+        min_cells = None
+        average_dist = int(distance_nn[int(len(table) * 0.8)])
+        threshold = threshold_erode if threshold_erode is not None else average_dist
+    else:
+        iterations = 15
+        min_cells = 20000
+        threshold = threshold_erode if threshold_erode is not None else 40
+
+    print(f"Using a threshold of {threshold} micrometer for eroding the segmentation with keyword {keyword}.")
+
+    new_subset = erode_subset(table.copy(), iterations=iterations,
+                              threshold=threshold, min_cells=min_cells, keyword=keyword)
+    eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)
+    # Label connected components.
+    labeled, num_features = label(eroded_arr)
+
+    # Find the largest component.
+    sizes = [(labeled == i).sum() for i in range(1, num_features + 1)]
+    largest_label = np.argmax(sizes) + 1
+
+    # Extract only the largest component and fill its holes.
+    largest_component = (labeled == largest_label).astype(np.uint8)
+    largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)
+
+    # Filter out small components.
+    if filter_small_components is not None:
+        for size, feature in zip(sizes, range(1, num_features + 1)):
+            if size < filter_small_components:
+                labeled[labeled == feature] = 0
+
+    return labeled, largest_component_filtered
+
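Illustrative usage only: running the erosion and connected-component step directly. Here table stands for a segmentation dataframe (loaded elsewhere) that already carries the anchor_* columns and the distance_nn100 statistic.

import numpy as np

labeled, largest = erode_sgn_seg(table, keyword="distance_nn100",
                                 filter_small_components=10, scale_factor=20)
print(np.unique(labeled))   # component labels that survived filtering (0 is background)
print(int(largest.sum()))   # size of the largest component in downscaled blocks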
+
+def get_components(table: pd.DataFrame,
+                   labeled: np.typing.NDArray,
+                   scale_factor: float,
+                   distance_component: Optional[int] = 0,
+                   ) -> list:
+    """Index the coordinates according to a labeled array.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        labeled: Array containing the differently labeled components.
+        scale_factor: Scale factor for the downsampling.
+        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are still included.
+
+    Returns:
+        A list with one component label per row of the table.
+    """
+    unique_labels = list(np.unique(labeled))
+    component_labels = [0 for _ in range(len(table))]
+    for label_index, l in enumerate(unique_labels):
+        if l != 0:
+            label_arr = (labeled == l).astype(np.uint8)
+            centroids_binary = coordinates_in_downscaled_blocks(table, label_arr,
+                                                                scale_factor, distance_component=distance_component)
+            for num, c in enumerate(centroids_binary):
+                if c != 0:
+                    component_labels[num] = label_index
+    return component_labels
+
+
+def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: Optional[float] = 20) -> pd.DataFrame:
+    """Postprocess the SGN segmentation of the cochlea.
+
+    Args:
+        table: Dataframe of the segmentation table.
+        scale_factor: Scale factor for the downsampling.
+
+    Returns:
+        The dataframe with an added column of component labels.
+    """
+    labeled, largest_component = erode_sgn_seg(table, filter_small_components=10,
+                                               scale_factor=scale_factor, threshold_erode=None)
+
+    component_labels = get_components(table, labeled, scale_factor, distance_component=1)
+
+    table.loc[:, "component_labels"] = component_labels
+
+    return table
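Illustrative end-to-end usage: load a segmentation table (the tab-separated path below is hypothetical), run the postprocessing, and inspect the resulting component labels. The table is expected to provide the anchor_x/anchor_y/anchor_z and distance_nn100 columns used above.

import pandas as pd

seg_table = pd.read_csv("segmentation_table.tsv", sep="\t")  # hypothetical input path
seg_table = postprocess_sgn_seg(seg_table, scale_factor=20)
print(seg_table["component_labels"].value_counts())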