Skip to content

Commit c091e54

Browse files
committed
Added postprocessing option for SGNs
1 parent a0cc704 commit c091e54

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -358,21 +358,23 @@ def components_sgn(
358358
table: pd.DataFrame,
359359
keyword: str = "distance_nn100",
360360
threshold_erode: Optional[float] = None,
361-
postprocess_graph: bool = False,
362361
min_component_length: int = 50,
363362
min_edge_distance: float = 30,
364363
iterations_erode: Optional[int] = None,
364+
postprocess_threshold: Optional[float] = None,
365+
postprocess_components: Optional[List[int]] = None,
365366
) -> List[List[int]]:
366367
"""Eroding the SGN segmentation.
367368
368369
Args:
369370
table: Dataframe of segmentation table.
370371
keyword: Keyword of the dataframe column for erosion.
371372
threshold_erode: Threshold of column value after erosion step with spatial statistics.
372-
postprocess_graph: Post-process graph connected components by searching for near points.
373373
min_component_length: Minimal length for filtering out connected components.
374374
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
375375
iterations_erode: Number of iterations for erosion, normally determined automatically.
376+
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
377+
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
376378
377379
Returns:
378380
Subgraph components as lists of label_ids of dataframe.
@@ -411,20 +413,31 @@ def components_sgn(
411413

412414
components = graph_connected_components(coords, min_edge_distance, min_component_length)
413415

416+
length_components = [len(c) for c in components]
417+
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
418+
414419
# add original coordinates closer to eroded component than threshold
415-
if postprocess_graph:
416-
threshold = 15
420+
if postprocess_threshold is not None:
421+
if postprocess_components is None:
422+
pp_components = components
423+
else:
424+
pp_components = [components[i] for i in postprocess_components]
425+
426+
add_coords = []
417427
for label_id, centr in zip(labels, centroids):
418428
if label_id not in labels_subset:
419429
add_coord = []
420-
for comp_index, component in enumerate(components):
430+
for comp_index, component in enumerate(pp_components):
421431
for comp_label in component:
422432
dist = math.dist(centr, centroids[comp_label - 1])
423-
if dist <= threshold:
433+
if dist <= postprocess_threshold:
424434
add_coord.append([comp_index, label_id])
425435
break
426436
if len(add_coord) != 0:
427-
components[add_coord[0][0]].append(add_coord[0][1])
437+
add_coords.append(add_coord)
438+
if len(add_coords) != 0:
439+
for c in add_coords:
440+
components[c[0][0]].append(c[0][1])
428441

429442
return components
430443

@@ -436,6 +449,8 @@ def label_components_sgn(
436449
min_component_length: int = 50,
437450
min_edge_distance: float = 30,
438451
iterations_erode: Optional[int] = None,
452+
postprocess_threshold: Optional[float] = None,
453+
postprocess_components: Optional[List[int]] = None,
439454
) -> List[int]:
440455
"""Label SGN components using graph connected components.
441456
@@ -446,6 +461,8 @@ def label_components_sgn(
446461
min_component_length: Minimal length for filtering out connected components.
447462
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
448463
iterations_erode: Number of iterations for erosion, normally determined automatically.
464+
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
465+
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
449466
450467
Returns:
451468
List of component label for each point in dataframe. 0 - background, then in descending order of size
@@ -456,15 +473,14 @@ def label_components_sgn(
456473
table = table[table.n_pixels >= min_size]
457474

458475
components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
459-
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
476+
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode,
477+
postprocess_threshold=postprocess_threshold,
478+
postprocess_components=postprocess_components)
460479

461480
# add size-filtered objects to have same initial length
462481
table = pd.concat([table, entries_filtered], ignore_index=True)
463482
table.sort_values("label_id")
464483

465-
length_components = [len(c) for c in components]
466-
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
467-
468484
component_labels = [0 for _ in range(len(table))]
469485
# be aware of 'label_id' of dataframe starting at 1
470486
for lab, comp in enumerate(components):

0 commit comments

Comments
 (0)