+import math
 import multiprocessing as mp
-import os
 from concurrent import futures
-from typing import Callable, Tuple, Optional
+from typing import Callable, List, Optional, Tuple
 
 import elf.parallel as parallel
 import numpy as np
 import nifty.tools as nt
+import networkx as nx
 import pandas as pd
 
 from elf.io import open_file
@@ -258,16 +259,16 @@ def erode_subset(
 def downscaled_centroids(
     table: pd.DataFrame,
     scale_factor: int,
-    ref_dimensions: Optional[Tuple[float, float, float]] = None,
-    capped: Optional[bool] = True,
+    ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    downsample_mode: Optional[str] = "accumulated",
 ) -> np.typing.NDArray:
     """Downscale centroids in dataframe.
 
     Args:
         table: Dataframe of segmentation table.
         scale_factor: Factor for downscaling coordinates.
         ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied.
-        capped: Flag for capping output of array at 1 for the creation of a binary mask.
+        downsample_mode: Downsampling mode, one of 'accumulated', 'capped', or 'components'.
 
     Returns:
         The downscaled array.
@@ -284,23 +285,35 @@ def downscaled_centroids(
     bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions])
     new_array = np.zeros(bounding_dimensions_scaled)
 
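+    # write each downscaled centroid into its voxel; how values are combined depends on downsample_mode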
-    for c in centroids_scaled:
-        new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+    if downsample_mode == "accumulated":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] += 1
 
-    array_downscaled = np.round(new_array).astype(int)
+    elif downsample_mode == "capped":
+        for c in centroids_scaled:
+            new_array[int(c[0]), int(c[1]), int(c[2])] += 1
+        new_array = np.round(new_array).astype(int)
+        new_array[new_array >= 1] = 1
 
-    if capped:
-        array_downscaled[array_downscaled >= 1] = 1
+    elif downsample_mode == "components":
+        if "component_labels" not in table.columns:
+            raise KeyError("Dataframe must contain key 'component_labels' for downsampling with mode 'components'.")
+        component_labels = list(table["component_labels"])
+        for comp, centr in zip(component_labels, centroids_scaled):
+            if comp != 0:
+                new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp
+        new_array = np.round(new_array).astype(int)
 
-    return array_downscaled
+    else:
+        raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.")
+
+    return new_array
 
 
 def coordinates_in_downscaled_blocks(
     table: pd.DataFrame,
     down_array: np.typing.NDArray,
     scale_factor: float,
     distance_component: Optional[int] = 0,
-) -> list:
+) -> List[int]:
     """Checking if coordinates are within the downscaled array.
 
     Args:
@@ -318,12 +331,12 @@ def coordinates_in_downscaled_blocks(
 
     # check if input coordinates are within down-sampled blocks
     centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
-    centroids_scaled = [np.floor(np.array([c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor])) for c in centroids]
+    centroids = [np.floor(np.array([c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor])) for c in centroids]
 
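+    # distance of every empty voxel in the downscaled array to the nearest occupied voxel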
     distance_map = distance_transform_edt(down_array == 0)
 
     centroids_binary = []
-    for c in centroids_scaled:
+    for c in centroids:
         coord = (int(c[0]), int(c[1]), int(c[2]))
         if down_array[coord] != 0:
             centroids_binary.append(1)
@@ -335,13 +348,81 @@ def coordinates_in_downscaled_blocks(
     return centroids_binary
 
 
-def erode_sgn_seg(
+def erode_sgn_seg_graph(
+    table: pd.DataFrame,
+    keyword: Optional[str] = "distance_nn100",
+    threshold_erode: Optional[float] = None,
+) -> List[List[int]]:
+    """Eroding the SGN segmentation and grouping the remaining cells into components via a graph.
+
+    Args:
+        table: Dataframe of segmentation table.
+        keyword: Keyword of the dataframe column for erosion.
+        threshold_erode: Threshold of column value after erosion step with spatial statistics.
+
+    Returns:
+        Subgraph components as lists of label_ids of the dataframe.
+    """
+    print("initial length", len(table))
+    distance_nn = list(table[keyword])
+    distance_nn.sort()
+
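+    # choose erosion parameters from the table size: small tables get a single iteration with a
+    # data-derived threshold (80th percentile of the distance column), large tables a fixed threshold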
+    if len(table) < 20000:
+        iterations = 1
+        min_cells = None
+        average_dist = int(distance_nn[int(len(table) * 0.8)])
+        threshold = threshold_erode if threshold_erode is not None else average_dist
+    else:
+        iterations = 15
+        min_cells = 20000
+        threshold = threshold_erode if threshold_erode is not None else 40
+
+    print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")
+
+    new_subset = erode_subset(table.copy(), iterations=iterations,
+                              threshold=threshold, min_cells=min_cells, keyword=keyword)
+
+    # create graph from coordinates of eroded subset
+    centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
+    labels_subset = [int(i) for i in list(new_subset["label_id"])]
+    coords = {}
+    for index, element in zip(labels_subset, centroids_subset):
+        coords[index] = element
+
+    graph = nx.Graph()
+    for num, pos in coords.items():
+        graph.add_node(num, pos=pos)
+
+    # create edges between points whose distance is less than threshold
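+    # NOTE: the edge threshold is hard-coded to 30 here and overrides the erosion threshold chosen above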
+    threshold = 30
+    for i in coords:
+        for j in coords:
+            if i < j:
+                dist = math.dist(coords[i], coords[j])
+                if dist <= threshold:
+                    graph.add_edge(i, j, weight=dist)
+
+    components = list(nx.connected_components(graph))
+
+    # remove connected components with fewer nodes than the minimum length
+    min_length = 100
+    for component in components:
+        if len(component) < min_length:
+            for c in component:
+                graph.remove_node(c)
+
+    components = list(nx.connected_components(graph))
+
+    return components
+
+
+def erode_sgn_seg_downscaling(
     table: pd.DataFrame,
     keyword: Optional[str] = "distance_nn100",
     filter_small_components: Optional[int] = None,
     scale_factor: Optional[float] = 20,
     threshold_erode: Optional[float] = None,
-) -> Tuple[pd.DataFrame, np.typing.NDArray, np.typing.NDArray, np.typing.NDArray]:
+) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
     """Eroding the SGN segmentation.
 
     Args:
@@ -355,7 +436,6 @@ def erode_sgn_seg(
         The labeled components of the downscaled, eroded coordinates.
         The largest connected component of the labeled components.
     """
-
     ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
     print("initial length", len(table))
     distance_nn = list(table[keyword])
@@ -375,7 +455,9 @@ def erode_sgn_seg(
 
     new_subset = erode_subset(table.copy(), iterations=iterations,
                               threshold=threshold, min_cells=min_cells, keyword=keyword)
+
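+    # rasterize the eroded centroids into a downscaled occupancy array for connected-component labeling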
     eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)
+
     # Label connected components
     labeled, num_features = label(eroded_arr)
381463
@@ -387,7 +469,7 @@ def erode_sgn_seg(
     largest_component = (labeled == largest_label).astype(np.uint8)
     largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)
 
-    #filter small sizes
+    # filter small sizes
     if filter_small_components is not None:
         for (size, feature) in zip(sizes, range(1, num_features + 1)):
             if size < filter_small_components:
@@ -396,11 +478,12 @@ def erode_sgn_seg(
     return labeled, largest_component_filtered
 
 
-def get_components(table: pd.DataFrame,
+def get_components(
+    table: pd.DataFrame,
     labeled: np.typing.NDArray,
     scale_factor: float,
     distance_component: Optional[int] = 0,
-) -> list:
+) -> List[int]:
     """Indexing coordinates according to labeled array.
 
     Args:
@@ -423,29 +506,71 @@ def get_components(table: pd.DataFrame,
     for label_index, l in enumerate(unique_labels):
         label_arr = (labeled == l).astype(np.uint8)
         centroids_binary = coordinates_in_downscaled_blocks(table, label_arr,
-                                                            scale_factor, distance_component=distance_component)
+                                                             scale_factor, distance_component=distance_component)
         for num, c in enumerate(centroids_binary):
             if c != 0:
                 component_labels[num] = label_index + 1
 
     return component_labels
 
 
-def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: Optional[float] = 20) -> pd.DataFrame:
+def component_labels_graph(table: pd.DataFrame) -> List[int]:
+    """Label components using graph connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+
+    Returns:
+        List of component labels, one for each point in the dataframe.
+    """
+    components = erode_sgn_seg_graph(table)
+
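+    # sort components by size, largest first, so that the largest component receives label 1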
+    length_components = [len(c) for c in components]
+    length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
+
+    component_labels = [0 for _ in range(len(table))]
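+    # NOTE: this assumes the label_id values stored as graph nodes are valid row indices into the table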
+    for lab, comp in enumerate(components):
+        for comp_index in comp:
+            component_labels[comp_index] = lab + 1
+
+    return component_labels
+
+
+def component_labels_downscaling(table: pd.DataFrame, scale_factor: float = 20) -> List[int]:
+    """Label components using downscaling and connected components.
+
+    Args:
+        table: Dataframe of segmentation table.
+        scale_factor: Factor for downscaling.
+
+    Returns:
+        List of component labels, one for each point in the dataframe.
+    """
+    labeled, largest_component = erode_sgn_seg_downscaling(table, filter_small_components=10,
+                                                           scale_factor=scale_factor, threshold_erode=None)
+    component_labels = get_components(table, labeled, scale_factor, distance_component=1)
+
+    return component_labels
+
+
+def postprocess_sgn_seg(
+    table: pd.DataFrame,
+    postprocess_type: Optional[str] = "downsampling",
+) -> pd.DataFrame:
     """Postprocessing SGN segmentation of cochlea.
 
     Args:
         table: Dataframe of segmentation table.
-        scale_factor: Scaling for downsampling.
+        postprocess_type: Postprocessing method, either 'downsampling' or 'graph'.
 
     Returns:
         Dataframe with component labels.
     """
-    labeled, largest_component = erode_sgn_seg(table, filter_small_labels=10,
-                                               scale_factor=scale_factor, threshold_erode=None)
-
-    component_labels = get_components(table, labeled, scale_factor, distance_component=1)
+    if postprocess_type == "downsampling":
+        component_labels = component_labels_downscaling(table)
+    elif postprocess_type == "graph":
+        component_labels = component_labels_graph(table)
+    else:
+        raise ValueError("Choose one of the postprocessing types 'downsampling' or 'graph'.")
 
     table.loc[:, "component_labels"] = component_labels
 
-    return table
+    return table