Skip to content

Commit f2355d5

Browse files
committed
Initial functions for postprocessing IHC segmentation
1 parent 534ab38 commit f2355d5

File tree

1 file changed

+117
-25
lines changed

1 file changed

+117
-25
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 117 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,31 @@ def downscaled_centroids(
319319
return new_array
320320

321321

322+
def graph_connected_components(coords, min_edge_distance, min_component_length):
    """Group points into connected components of a distance-threshold graph.

    Every key of `coords` becomes a graph node. Two nodes are linked by an edge when the
    Euclidean distance between their positions is at most `min_edge_distance`; despite the
    "min" in its name, this parameter is an inclusive upper bound on the edge length.
    Components with fewer than `min_component_length` nodes are dropped.

    Args:
        coords: Mapping from node identifier to its position, given as a sequence of
            coordinates accepted by `math.dist`.
        min_edge_distance: Largest distance (inclusive) at which two points are connected.
        min_component_length: Smallest number of nodes a component needs to be kept.

    Returns:
        List of the retained components, each given as a list of node identifiers.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import combinations

    graph = nx.Graph()
    # Add every point up front so that isolated points still form (size-1) components.
    graph.add_nodes_from(coords)

    # Visit each unordered pair exactly once instead of scanning all ordered pairs
    # and discarding half of them.
    for (node_a, pos_a), (node_b, pos_b) in combinations(coords.items(), 2):
        if math.dist(pos_a, pos_b) <= min_edge_distance:
            graph.add_edge(node_a, node_b)

    # Removing a whole component never changes the remaining ones, so filtering the
    # components directly is equivalent to deleting the nodes and recomputing.
    return [
        list(component)
        for component in nx.connected_components(graph)
        if len(component) >= min_component_length
    ]
345+
346+
322347
def components_sgn(
323348
table: pd.DataFrame,
324349
keyword: str = "distance_nn100",
@@ -370,27 +395,7 @@ def components_sgn(
370395
for index, element in zip(labels_subset, centroids_subset):
371396
coords[index] = element
372397

373-
graph = nx.Graph()
374-
for num, pos in coords.items():
375-
graph.add_node(num, pos=pos)
376-
377-
# create edges between points whose distance is less than threshold min_edge_distance
378-
for i in coords:
379-
for j in coords:
380-
if i < j:
381-
dist = math.dist(coords[i], coords[j])
382-
if dist <= min_edge_distance:
383-
graph.add_edge(i, j, weight=dist)
384-
385-
components = list(nx.connected_components(graph))
386-
387-
# remove connected components with less nodes than threshold min_component_length
388-
for component in components:
389-
if len(component) < min_component_length:
390-
for c in component:
391-
graph.remove_node(c)
392-
393-
components = [list(s) for s in nx.connected_components(graph)]
398+
components = graph_connected_components(coords, min_edge_distance, min_component_length)
394399

395400
# add original coordinates closer to eroded component than threshold
396401
if postprocess_graph:
@@ -410,7 +415,7 @@ def components_sgn(
410415
return components
411416

412417

413-
def label_components(
418+
def label_components_sgn(
414419
table: pd.DataFrame,
415420
min_size: int = 1000,
416421
threshold_erode: Optional[float] = None,
@@ -477,9 +482,96 @@ def postprocess_sgn_seg(
477482
Dataframe with component labels.
478483
"""
479484

480-
comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
481-
min_component_length=min_component_length,
482-
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
485+
comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode,
486+
min_component_length=min_component_length,
487+
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
488+
489+
table.loc[:, "component_labels"] = comp_labels
490+
491+
return table
492+
493+
494+
def components_ihc(
    table: pd.DataFrame,
    min_component_length: int = 50,
    min_edge_distance: float = 30,
):
    """Compute connected components for the instances listed in the segmentation table.

    Each row contributes one point, keyed by its "label_id" (cast to int) and positioned at
    its ("anchor_x", "anchor_y", "anchor_z") coordinates. The resulting mapping is handed to
    `graph_connected_components`.

    Args:
        table: Dataframe of segmentation table.
        min_component_length: Minimal length for filtering out connected components.
        min_edge_distance: Distance threshold forwarded to `graph_connected_components`
            for creating edges between points.

    Returns:
        The components as returned by `graph_connected_components`.
    """
    anchors = zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])
    # If a label id occurs more than once, the last row wins.
    coords = {int(label_id): anchor for label_id, anchor in zip(table["label_id"], anchors)}
    return graph_connected_components(coords, min_edge_distance, min_component_length)
507+
508+
509+
def label_components_ihc(
    table: pd.DataFrame,
    min_size: int = 1000,
    min_component_length: int = 50,
    min_edge_distance: float = 30,
) -> List[int]:
    """Label components using graph connected components.

    Args:
        table: Dataframe of segmentation table. The columns "n_pixels" and "label_id" are
            read here; `components_ihc` reads further columns.
        min_size: Minimal number of pixels for filtering small instances.
        min_component_length: Minimal length for filtering out connected components.
        min_edge_distance: Distance threshold in micrometer between points to create edges
            for connected components.

    Returns:
        List with one component label per row of `table`, in row order.
        0 - background (instances below `min_size` or outside every retained component),
        then 1, 2, ... in descending order of component size.
    """
    # Apply the size filter; rows below the threshold keep the background label.
    table_filtered = table[table.n_pixels >= min_size]

    components = components_ihc(table_filtered, min_component_length=min_component_length,
                                min_edge_distance=min_edge_distance)

    # Largest component first; ties are broken by comparing the member lists, which is the
    # ordering the previous tuple-based sort produced. Sorting the list directly also
    # works when no component survived the filtering.
    ordered = sorted(components, key=lambda comp: (len(comp), comp), reverse=True)

    component_of = {}
    for lab, comp in enumerate(ordered, start=1):
        for label_id in comp:
            component_of[label_id] = lab

    # Look the label up per row instead of indexing with `label_id - 1`, so the result
    # stays aligned with the table even if the ids are not exactly 1..N in row order.
    return [component_of.get(int(label_id), 0) for label_id in table["label_id"]]
550+
551+
552+
def postprocess_ihc_seg(
553+
table: pd.DataFrame,
554+
min_size: int = 1000,
555+
min_component_length: int = 50,
556+
min_edge_distance: float = 30,
557+
) -> pd.DataFrame:
558+
"""Postprocessing SGN segmentation of cochlea.
559+
560+
Args:
561+
table: Dataframe of segmentation table.
562+
min_size: Minimal number of pixels for filtering small instances.
563+
threshold_erode: Threshold of column value after erosion step with spatial statistics.
564+
min_component_length: Minimal length for filtering out connected components.
565+
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
566+
iterations_erode: Number of iterations for erosion, normally determined automatically.
567+
568+
Returns:
569+
Dataframe with component labels.
570+
"""
571+
572+
comp_labels = label_components_ihc(table, min_size=min_size,
573+
min_component_length=min_component_length,
574+
min_edge_distance=min_edge_distance)
483575

484576
table.loc[:, "component_labels"] = comp_labels
485577

0 commit comments

Comments
 (0)