Skip to content

Commit c6e2b04

Browse files
committed
Added graph connected components for postprocessing
1 parent 2fa29b1 commit c6e2b04

File tree

1 file changed

+153
-28
lines changed

1 file changed

+153
-28
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 153 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import math
12
import multiprocessing as mp
2-
import os
33
from concurrent import futures
4-
from typing import Callable, Tuple, Optional
4+
from typing import Callable, List, Optional, Tuple
55

66
import elf.parallel as parallel
77
import numpy as np
88
import nifty.tools as nt
9+
import networkx as nx
910
import pandas as pd
1011

1112
from elf.io import open_file
@@ -258,16 +259,16 @@ def erode_subset(
258259
def downscaled_centroids(
259260
table: pd.DataFrame,
260261
scale_factor: int,
261-
ref_dimensions: Optional[Tuple[float,float,float]] = None,
262-
capped: Optional[bool] = True,
262+
ref_dimensions: Optional[Tuple[float, float, float]] = None,
263+
downsample_mode: Optional[str] = "accumulated",
263264
) -> np.typing.NDArray:
264265
"""Downscale centroids in dataframe.
265266
266267
Args:
267268
table: Dataframe of segmentation table.
268269
scale_factor: Factor for downscaling coordinates.
269270
ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied.
270-
capped: Flag for capping output of array at 1 for the creation of a binary mask.
271+
downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'
271272
272273
Returns:
273274
The downscaled array
@@ -284,23 +285,35 @@ def downscaled_centroids(
284285
bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions])
285286
new_array = np.zeros(bounding_dimensions_scaled)
286287

287-
for c in centroids_scaled:
288-
new_array[int(c[0]), int(c[1]), int(c[2])] += 1
288+
if downsample_mode == "accumulated":
289+
for c in centroids_scaled:
290+
new_array[int(c[0]), int(c[1]), int(c[2])] += 1
289291

290-
array_downscaled = np.round(new_array).astype(int)
292+
elif downsample_mode == "capped":
293+
new_array = np.round(new_array).astype(int)
294+
new_array[new_array >= 1] = 1
291295

292-
if capped:
293-
array_downscaled[array_downscaled >= 1] = 1
296+
elif downsample_mode == "components":
297+
if "component_labels" not in table.columns:
298+
raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.")
299+
component_labels = list(table["component_labels"])
300+
for comp, centr in zip(component_labels, centroids_scaled):
301+
if comp != 0:
302+
new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp
303+
new_array = np.round(new_array).astype(int)
294304

295-
return array_downscaled
305+
else:
306+
raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.")
307+
308+
return new_array
296309

297310

298311
def coordinates_in_downscaled_blocks(
299312
table: pd.DataFrame,
300313
down_array: np.typing.NDArray,
301314
scale_factor: float,
302315
distance_component: Optional[int] = 0,
303-
) -> list:
316+
) -> List[int]:
304317
"""Checking if coordinates are within the downscaled array.
305318
306319
Args:
@@ -318,12 +331,12 @@ def coordinates_in_downscaled_blocks(
318331

319332
# check if input coordinates are within down-sampled blocks
320333
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
321-
centroids_scaled = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids]
334+
centroids = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids]
322335

323336
distance_map = distance_transform_edt(down_array == 0)
324337

325338
centroids_binary = []
326-
for c in centroids_scaled:
339+
for c in centroids:
327340
coord = (int(c[0]), int(c[1]), int(c[2]))
328341
if down_array[coord] != 0:
329342
centroids_binary.append(1)
@@ -335,13 +348,81 @@ def coordinates_in_downscaled_blocks(
335348
return centroids_binary
336349

337350

338-
def erode_sgn_seg(
351+
def erode_sgn_seg_graph(
    table: pd.DataFrame,
    keyword: Optional[str] = "distance_nn100",
    threshold_erode: Optional[float] = None,
    edge_threshold: float = 30,
    min_component_length: int = 100,
) -> List[List[int]]:
    """Eroding the SGN segmentation.

    The table is first thinned out with `erode_subset` based on the spatial-statistics
    column given by `keyword`. The surviving centroids are connected in a graph whenever
    they lie within `edge_threshold` of each other, and connected components with fewer
    than `min_component_length` nodes are discarded.

    Args:
        table: Dataframe of segmentation table.
        keyword: Keyword of the dataframe column for erosion.
        threshold_erode: Threshold of column value after erosion step with spatial statistics.
        edge_threshold: Maximal distance (in micrometer) between two centroids for a graph edge.
        min_component_length: Minimal number of nodes for a connected component to be kept.

    Returns:
        Subgraph components as lists of label_ids of dataframe.
    """
    # Local import: scipy is only needed for the KD-tree based edge construction.
    from scipy.spatial import cKDTree

    print("initial length", len(table))
    distance_nn = sorted(table[keyword])

    if len(table) < 20000:
        iterations = 1
        min_cells = None
        # Default threshold: 80th percentile of the nearest-neighbor distances.
        average_dist = int(distance_nn[int(len(table) * 0.8)])
        threshold = threshold_erode if threshold_erode is not None else average_dist
    else:
        iterations = 15
        min_cells = 20000
        threshold = threshold_erode if threshold_erode is not None else 40

    print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")

    new_subset = erode_subset(table.copy(), iterations=iterations,
                              threshold=threshold, min_cells=min_cells, keyword=keyword)

    # create graph from coordinates of eroded subset
    centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
    labels_subset = [int(i) for i in new_subset["label_id"]]

    graph = nx.Graph()
    for num, pos in zip(labels_subset, centroids_subset):
        graph.add_node(num, pos=pos)

    # Create edges between points whose distance is at most edge_threshold.
    # A KD-tree query replaces the previous quadratic all-pairs distance loop.
    tree = cKDTree(centroids_subset)
    for a, b in tree.query_pairs(r=edge_threshold):
        dist = math.dist(centroids_subset[a], centroids_subset[b])
        graph.add_edge(labels_subset[a], labels_subset[b], weight=dist)

    # remove connected components with fewer nodes than min_component_length
    for component in list(nx.connected_components(graph)):
        if len(component) < min_component_length:
            graph.remove_nodes_from(component)

    components = list(nx.connected_components(graph))

    return components
417+
418+
419+
def erode_sgn_seg_downscaling(
339420
table: pd.DataFrame,
340421
keyword: Optional[str] = "distance_nn100",
341422
filter_small_components: Optional[int] = None,
342423
scale_factor: Optional[float] = 20,
343424
threshold_erode: Optional[float] = None,
344-
) -> Tuple[pd.DataFrame,np.typing.NDArray,np.typing.NDArray,np.typing.NDArray]:
425+
) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
345426
"""Eroding the SGN segmentation.
346427
347428
Args:
@@ -355,7 +436,6 @@ def erode_sgn_seg(
355436
The labeled components of the downscaled, eroded coordinates.
356437
The larget connected component of the labeled components.
357438
"""
358-
359439
ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
360440
print("initial length", len(table))
361441
distance_nn = list(table[keyword])
@@ -375,7 +455,9 @@ def erode_sgn_seg(
375455

376456
new_subset = erode_subset(table.copy(), iterations=iterations,
377457
threshold=threshold, min_cells=min_cells, keyword=keyword)
458+
378459
eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)
460+
379461
# Label connected components
380462
labeled, num_features = label(eroded_arr)
381463

@@ -387,7 +469,7 @@ def erode_sgn_seg(
387469
largest_component = (labeled == largest_label).astype(np.uint8)
388470
largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)
389471

390-
#filter small sizes
472+
# filter small sizes
391473
if filter_small_components is not None:
392474
for (size, feature) in zip(sizes, range(1, num_features + 1)):
393475
if size < filter_small_components:
@@ -396,11 +478,12 @@ def erode_sgn_seg(
396478
return labeled, largest_component_filtered
397479

398480

399-
def get_components(table: pd.DataFrame,
481+
def get_components(
482+
table: pd.DataFrame,
400483
labeled: np.typing.NDArray,
401484
scale_factor: float,
402485
distance_component: Optional[int] = 0,
403-
) -> list:
486+
) -> List[int]:
404487
"""Indexing coordinates according to labeled array.
405488
406489
Args:
@@ -423,29 +506,71 @@ def get_components(table: pd.DataFrame,
423506
for label_index, l in enumerate(unique_labels):
424507
label_arr = (labeled == l).astype(np.uint8)
425508
centroids_binary = coordinates_in_downscaled_blocks(table, label_arr,
426-
scale_factor, distance_component = distance_component)
509+
scale_factor, distance_component=distance_component)
427510
for num, c in enumerate(centroids_binary):
428511
if c != 0:
429512
component_labels[num] = label_index + 1
430513

431514
return component_labels
432515

433516

434-
def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: Optional[float] = 20) -> pd.DataFrame:
517+
def component_labels_graph(table: pd.DataFrame) -> List[int]:
    """Label components using graph connected components.

    Args:
        table: Dataframe of segmentation table.

    Returns:
        List of component label for each point in dataframe. Labels are 1-based and
        ordered by component size (1 = largest); 0 marks points in no kept component.
    """
    components = erode_sgn_seg_graph(table)

    # Sort by size only. The previous zip/sort compared the component *sets*
    # on size ties (an ill-defined partial order) and crashed via
    # zip(*...) unpacking when no component survived.
    components = sorted(components, key=len, reverse=True)

    # Components contain label_ids, which are not necessarily the positional
    # row indices of the dataframe — map them explicitly to row positions.
    row_of_label = {int(label_id): row for row, label_id in enumerate(table["label_id"])}

    component_labels = [0] * len(table)
    for lab, comp in enumerate(components):
        for label_id in comp:
            component_labels[row_of_label[int(label_id)]] = lab + 1

    return component_labels
537+
538+
539+
def component_labels_downscaling(table: pd.DataFrame, scale_factor: float = 20) -> List[int]:
    """Label components using downscaling and connected components.

    Args:
        table: Dataframe of segmentation table.
        scale_factor: Factor for downscaling.

    Returns:
        List of component label for each point in dataframe.
    """
    # Erode the segmentation on a downscaled grid; only the labeled component
    # array is needed here, the largest-component mask is discarded.
    labeled_array, _ = erode_sgn_seg_downscaling(
        table,
        filter_small_components=10,
        scale_factor=scale_factor,
        threshold_erode=None,
    )
    # Assign each centroid of the table to its connected component.
    return get_components(table, labeled_array, scale_factor, distance_component=1)
554+
555+
556+
def postprocess_sgn_seg(
    table: pd.DataFrame,
    postprocess_type: Optional[str] = "downsampling",
) -> pd.DataFrame:
    """Postprocessing SGN segmentation of cochlea.

    Args:
        table: Dataframe of segmentation table.
        postprocess_type: Postprocessing method, either 'downsampling' or 'graph'.

    Returns:
        Dataframe with component labels.

    Raises:
        ValueError: If `postprocess_type` is not a supported method.
    """
    if postprocess_type == "downsampling":
        component_labels = component_labels_downscaling(table)
    elif postprocess_type == "graph":
        component_labels = component_labels_graph(table)
    else:
        # Previously an unknown type fell through and crashed later with a
        # NameError on the unbound `component_labels`.
        raise ValueError("Choose one of the postprocessing types 'downsampling' or 'graph'.")

    table.loc[:, "component_labels"] = component_labels

    return table

0 commit comments

Comments
 (0)