1010import pandas as pd
1111
1212from elf .io import open_file
13- from scipy .ndimage import binary_fill_holes
14- from scipy .ndimage import distance_transform_edt
15- from scipy .ndimage import label
1613from scipy .sparse import csr_matrix
1714from scipy .spatial import distance
1815from scipy .spatial import cKDTree , ConvexHull
@@ -212,11 +209,6 @@ def filter_chunk(block_id):
212209 return n_ids , n_ids_filtered
213210
214211
215- # Postprocess segmentation by erosion using the above spatial statistics.
216- # Currently implemented using downscaling and looking for connected components
217- # TODO: Change implementation to graph connected components.
218-
219-
220212def erode_subset (
221213 table : pd .DataFrame ,
222214 iterations : Optional [int ] = 1 ,
@@ -242,7 +234,6 @@ def erode_subset(
242234 for i in range (iterations ):
243235 table = table [table [keyword ] < threshold ]
244236
245- # TODO: support other spatial statistics
246237 distance_avg = nearest_neighbor_distance (table , n_neighbors = n_neighbors )
247238
248239 if min_cells is not None and len (distance_avg ) < min_cells :
@@ -309,72 +300,43 @@ def downscaled_centroids(
309300 return new_array
310301
311302
312- def coordinates_in_downscaled_blocks (
313- table : pd .DataFrame ,
314- down_array : np .typing .NDArray ,
315- scale_factor : float ,
316- distance_component : Optional [int ] = 0 ,
317- ) -> List [int ]:
318- """Checking if coordinates are within the downscaled array.
319-
320- Args:
321- table: Dataframe of segmentation table.
322- down_array: Downscaled array.
323- scale_factor: Factor which was used for downscaling.
324- distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included.
325-
326- Returns:
327- A binary list representing whether the dataframe coordinates are within the array.
328- """
329- # fill holes in down-sampled array
330- down_array [down_array > 0 ] = 1
331- down_array = binary_fill_holes (down_array ).astype (np .uint8 )
332-
333- # check if input coordinates are within down-sampled blocks
334- centroids = list (zip (table ["anchor_x" ], table ["anchor_y" ], table ["anchor_z" ]))
335- centroids = [np .floor (np .array ([c [0 ]/ scale_factor , c [1 ]/ scale_factor , c [2 ]/ scale_factor ])) for c in centroids ]
336-
337- distance_map = distance_transform_edt (down_array == 0 )
338-
339- centroids_binary = []
340- for c in centroids :
341- coord = (int (c [0 ]), int (c [1 ]), int (c [2 ]))
342- if down_array [coord ] != 0 :
343- centroids_binary .append (1 )
344- elif distance_map [coord ] <= distance_component :
345- centroids_binary .append (1 )
346- else :
347- centroids_binary .append (0 )
348-
349- return centroids_binary
350-
351-
352- def erode_sgn_seg_graph (
303+ def components_sgn (
353304 table : pd .DataFrame ,
354305 keyword : Optional [str ] = "distance_nn100" ,
355306 threshold_erode : Optional [float ] = None ,
307+ postprocess_graph : Optional [bool ] = False ,
308+ min_component_length : Optional [int ] = 50 ,
309+ min_edge_distance : Optional [float ] = 30 ,
310+ iterations_erode : Optional [int ] = None ,
356311) -> List [List [int ]]:
357312 """Eroding the SGN segmentation.
358313
359314 Args:
360315 table: Dataframe of segmentation table.
361316 keyword: Keyword of the dataframe column for erosion.
362317 threshold_erode: Threshold of column value after erosion step with spatial statistics.
318+ postprocess_graph: Post-process graph connected components by searching for near points.
319+ min_component_length: Minimal length for filtering out connected components.
320+ min_edge_distance: Maximal distance in micrometer between two points for an edge to be created between them.
321+ iterations_erode: Number of iterations for erosion, normally determined automatically.
363322
364323 Returns:
365324 Subgraph components as lists of label_ids of dataframe.
366325 """
326+ centroids = list (zip (table ["anchor_x" ], table ["anchor_y" ], table ["anchor_z" ]))
327+ labels = [int (i ) for i in list (table ["label_id" ])]
328+
367329 print ("initial length" , len (table ))
368330 distance_nn = list (table [keyword ])
369331 distance_nn .sort ()
370332
371333 if len (table ) < 20000 :
372- iterations = 1
334+ iterations = iterations_erode if iterations_erode is not None else 0
373335 min_cells = None
374336 average_dist = int (distance_nn [int (len (table ) * 0.8 )])
375337 threshold = threshold_erode if threshold_erode is not None else average_dist
376338 else :
377- iterations = 15
339+ iterations = iterations_erode if iterations_erode is not None else 15
378340 min_cells = 20000
379341 threshold = threshold_erode if threshold_erode is not None else 40
380342
@@ -394,183 +356,88 @@ def erode_sgn_seg_graph(
394356 for num , pos in coords .items ():
395357 graph .add_node (num , pos = pos )
396358
397- # create edges between points whose distance is less than threshold
398- threshold = 30
359+ # create edges between points whose distance is less than threshold min_edge_distance
399360 for i in coords :
400361 for j in coords :
401362 if i < j :
402363 dist = math .dist (coords [i ], coords [j ])
403- if dist <= threshold :
364+ if dist <= min_edge_distance :
404365 graph .add_edge (i , j , weight = dist )
405366
406367 components = list (nx .connected_components (graph ))
407368
408- # remove connected components with less nodes than threshold
409- min_length = 100
369+ # remove connected components with less nodes than threshold min_component_length
410370 for component in components :
411- if len (component ) < min_length :
371+ if len (component ) < min_component_length :
412372 for c in component :
413373 graph .remove_node (c )
414374
415- components = list (nx .connected_components (graph ))
375+ components = [list (s ) for s in nx .connected_components (graph )]
376+
377+ # add original coordinates closer to eroded component than threshold
378+ if postprocess_graph :
379+ threshold = 15
380+ for label_id , centr in zip (labels , centroids ):
381+ if label_id not in labels_subset :
382+ add_coord = []
383+ for comp_index , component in enumerate (components ):
384+ for comp_label in component :
385+ dist = math .dist (centr , centroids [comp_label - 1 ])
386+ if dist <= threshold :
387+ add_coord .append ([comp_index , label_id ])
388+ break
389+ if len (add_coord ) != 0 :
390+ components [add_coord [0 ][0 ]].append (add_coord [0 ][1 ])
416391
417392 return components
418393
419394
420- def erode_sgn_seg_downscaling (
def label_components(
    table: pd.DataFrame,
    threshold_erode: Optional[float] = None,
    min_component_length: Optional[int] = 50,
    min_edge_distance: Optional[float] = 30,
    iterations_erode: Optional[int] = None,
) -> List[int]:
    """Label components using graph connected components.

    The components are computed by `components_sgn`, which returns lists of
    'label_id' values. Each row of the table is then assigned the rank of the
    component that contains its 'label_id'.

    Args:
        table: Dataframe of segmentation table. Must contain the column 'label_id'.
        threshold_erode: Threshold of column value after erosion step with spatial statistics.
        min_component_length: Minimal number of points a connected component needs in order to be kept.
        min_edge_distance: Maximal distance in micrometer between two points for an edge to be created between them.
        iterations_erode: Number of iterations for erosion, normally determined automatically.

    Returns:
        List of component label for each point in dataframe. 0 - background, then in descending order of size.
        All labels are 0 if no component is found.
    """
    components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
                                min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)

    component_labels = [0 for _ in range(len(table))]

    # Nothing survived erosion / size filtering: everything stays background.
    if not components:
        return component_labels

    # Largest component first. Ties are broken by the component content,
    # which matches sorting (length, component) tuples in descending order.
    ordered_components = sorted(components, key=lambda comp: (len(comp), comp), reverse=True)

    # Map each 'label_id' to its row position instead of assuming that
    # 'label_id' starts at 1 and is contiguous (i.e. row == label_id - 1).
    row_of_label = {int(label_id): row for row, label_id in enumerate(table["label_id"])}

    for rank, component in enumerate(ordered_components, start=1):
        for label_id in component:
            component_labels[row_of_label[label_id]] = rank

    return component_labels
538427
539428
540- def component_labels_downscaling (table : pd .DataFrame , scale_factor : float = 20 ) -> List [int ]:
541- """Label components using downscaling and connected components.
542-
543- Args:
544- table: Dataframe of segmentation table.
545- scale_factor: Factor for downscaling.
546-
547- Returns:
548- List of component label for each point in dataframe.
549- """
550- labeled , largest_component = erode_sgn_seg_downscaling (table , filter_small_components = 10 ,
551- scale_factor = scale_factor , threshold_erode = None )
552- component_labels = get_components (table , labeled , scale_factor , distance_component = 1 )
553-
554- return component_labels
555-
556-
557429def postprocess_sgn_seg (
558430 table : pd .DataFrame ,
559- postprocess_type : Optional [str ] = "downsampling" ,
560431) -> pd .DataFrame :
561432 """Postprocessing SGN segmentation of cochlea.
562433
563434 Args:
564435 table: Dataframe of segmentation table.
565- postprocess_type: Postprocessing method, either 'downsampling' or 'graph'.
566436
567437 Returns:
568438 Dataframe with component labels.
569439 """
570- if postprocess_type == "downsampling" :
571- component_labels = component_labels_downscaling (table )
572- elif postprocess_type == "graph" :
573- component_labels = component_labels_graph (table )
440+ component_labels = label_components (table )
574441
575442 table .loc [:, "component_labels" ] = component_labels
576443
0 commit comments