Skip to content

Commit 67193f5

Browse files
committed
Removed post-processing with downsampling
1 parent c51947d commit 67193f5

File tree

1 file changed

+47
-180
lines changed

1 file changed

+47
-180
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 47 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
import pandas as pd
1111

1212
from elf.io import open_file
13-
from scipy.ndimage import binary_fill_holes
14-
from scipy.ndimage import distance_transform_edt
15-
from scipy.ndimage import label
1613
from scipy.sparse import csr_matrix
1714
from scipy.spatial import distance
1815
from scipy.spatial import cKDTree, ConvexHull
@@ -212,11 +209,6 @@ def filter_chunk(block_id):
212209
return n_ids, n_ids_filtered
213210

214211

215-
# Postprocess segmentation by erosion using the above spatial statistics.
216-
# Currently implemented using downscaling and looking for connected components
217-
# TODO: Change implementation to graph connected components.
218-
219-
220212
def erode_subset(
221213
table: pd.DataFrame,
222214
iterations: Optional[int] = 1,
@@ -242,7 +234,6 @@ def erode_subset(
242234
for i in range(iterations):
243235
table = table[table[keyword] < threshold]
244236

245-
# TODO: support other spatial statistics
246237
distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)
247238

248239
if min_cells is not None and len(distance_avg) < min_cells:
@@ -309,72 +300,43 @@ def downscaled_centroids(
309300
return new_array
310301

311302

312-
def coordinates_in_downscaled_blocks(
313-
table: pd.DataFrame,
314-
down_array: np.typing.NDArray,
315-
scale_factor: float,
316-
distance_component: Optional[int] = 0,
317-
) -> List[int]:
318-
"""Checking if coordinates are within the downscaled array.
319-
320-
Args:
321-
table: Dataframe of segmentation table.
322-
down_array: Downscaled array.
323-
scale_factor: Factor which was used for downscaling.
324-
distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included.
325-
326-
Returns:
327-
A binary list representing whether the dataframe coordinates are within the array.
328-
"""
329-
# fill holes in down-sampled array
330-
down_array[down_array > 0] = 1
331-
down_array = binary_fill_holes(down_array).astype(np.uint8)
332-
333-
# check if input coordinates are within down-sampled blocks
334-
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
335-
centroids = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids]
336-
337-
distance_map = distance_transform_edt(down_array == 0)
338-
339-
centroids_binary = []
340-
for c in centroids:
341-
coord = (int(c[0]), int(c[1]), int(c[2]))
342-
if down_array[coord] != 0:
343-
centroids_binary.append(1)
344-
elif distance_map[coord] <= distance_component:
345-
centroids_binary.append(1)
346-
else:
347-
centroids_binary.append(0)
348-
349-
return centroids_binary
350-
351-
352-
def erode_sgn_seg_graph(
303+
def components_sgn(
353304
table: pd.DataFrame,
354305
keyword: Optional[str] = "distance_nn100",
355306
threshold_erode: Optional[float] = None,
307+
postprocess_graph: Optional[bool] = False,
308+
min_component_length: Optional[int] = 50,
309+
min_edge_distance: Optional[float] = 30,
310+
iterations_erode: Optional[int] = None,
356311
) -> List[List[int]]:
357312
"""Eroding the SGN segmentation.
358313
359314
Args:
360315
table: Dataframe of segmentation table.
361316
keyword: Keyword of the dataframe column for erosion.
362317
threshold_erode: Threshold of column value after erosion step with spatial statistics.
318+
postprocess_graph: Post-process graph connected components by searching for near points.
319+
min_component_length: Minimal number of points a connected component must contain; smaller components are removed.
320+
min_edge_distance: Maximal distance in micrometer between two points for an edge to be created between them.
321+
iterations_erode: Number of iterations for erosion, normally determined automatically.
363322
364323
Returns:
365324
Subgraph components as lists of label_ids of dataframe.
366325
"""
326+
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
327+
labels = [int(i) for i in list(table["label_id"])]
328+
367329
print("initial length", len(table))
368330
distance_nn = list(table[keyword])
369331
distance_nn.sort()
370332

371333
if len(table) < 20000:
372-
iterations = 1
334+
iterations = iterations_erode if iterations_erode is not None else 0
373335
min_cells = None
374336
average_dist = int(distance_nn[int(len(table) * 0.8)])
375337
threshold = threshold_erode if threshold_erode is not None else average_dist
376338
else:
377-
iterations = 15
339+
iterations = iterations_erode if iterations_erode is not None else 15
378340
min_cells = 20000
379341
threshold = threshold_erode if threshold_erode is not None else 40
380342

@@ -394,183 +356,88 @@ def erode_sgn_seg_graph(
394356
for num, pos in coords.items():
395357
graph.add_node(num, pos=pos)
396358

397-
# create edges between points whose distance is less than threshold
398-
threshold = 30
359+
# create edges between points whose distance is at most min_edge_distance
399360
for i in coords:
400361
for j in coords:
401362
if i < j:
402363
dist = math.dist(coords[i], coords[j])
403-
if dist <= threshold:
364+
if dist <= min_edge_distance:
404365
graph.add_edge(i, j, weight=dist)
405366

406367
components = list(nx.connected_components(graph))
407368

408-
# remove connected components with less nodes than threshold
409-
min_length = 100
369+
# remove connected components with fewer nodes than min_component_length
410370
for component in components:
411-
if len(component) < min_length:
371+
if len(component) < min_component_length:
412372
for c in component:
413373
graph.remove_node(c)
414374

415-
components = list(nx.connected_components(graph))
375+
components = [list(s) for s in nx.connected_components(graph)]
376+
377+
# add original coordinates closer to eroded component than threshold
378+
if postprocess_graph:
379+
threshold = 15
380+
for label_id, centr in zip(labels, centroids):
381+
if label_id not in labels_subset:
382+
add_coord = []
383+
for comp_index, component in enumerate(components):
384+
for comp_label in component:
385+
dist = math.dist(centr, centroids[comp_label - 1])
386+
if dist <= threshold:
387+
add_coord.append([comp_index, label_id])
388+
break
389+
if len(add_coord) != 0:
390+
components[add_coord[0][0]].append(add_coord[0][1])
416391

417392
return components
418393

419394

420-
def erode_sgn_seg_downscaling(
395+
def label_components(
421396
table: pd.DataFrame,
422-
keyword: Optional[str] = "distance_nn100",
423-
filter_small_components: Optional[int] = None,
424-
scale_factor: Optional[float] = 20,
425397
threshold_erode: Optional[float] = None,
426-
) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
427-
"""Eroding the SGN segmentation.
428-
429-
Args:
430-
table: Dataframe of segmentation table.
431-
keyword: Keyword of the dataframe column for erosion.
432-
filter_small_components: If given, remove labeled components smaller than this number of blocks after labeling.
433-
scale_factor: Scaling for downsampling.
434-
threshold_erode: Threshold of column value after erosion step with spatial statistics.
435-
436-
Returns:
437-
The labeled components of the downscaled, eroded coordinates.
438-
The largest connected component of the labeled components.
439-
"""
440-
ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
441-
print("initial length", len(table))
442-
distance_nn = list(table[keyword])
443-
distance_nn.sort()
444-
445-
if len(table) < 20000:
446-
iterations = 1
447-
min_cells = None
448-
average_dist = int(distance_nn[int(len(table) * 0.8)])
449-
threshold = threshold_erode if threshold_erode is not None else average_dist
450-
else:
451-
iterations = 15
452-
min_cells = 20000
453-
threshold = threshold_erode if threshold_erode is not None else 40
454-
455-
print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")
456-
457-
new_subset = erode_subset(table.copy(), iterations=iterations,
458-
threshold=threshold, min_cells=min_cells, keyword=keyword)
459-
460-
eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)
461-
462-
# Label connected components
463-
labeled, num_features = label(eroded_arr)
464-
465-
# Find the largest component
466-
sizes = [(labeled == i).sum() for i in range(1, num_features + 1)]
467-
largest_label = np.argmax(sizes) + 1
468-
469-
# Extract only the largest component
470-
largest_component = (labeled == largest_label).astype(np.uint8)
471-
largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)
472-
473-
# filter small sizes
474-
if filter_small_components is not None:
475-
for (size, feature) in zip(sizes, range(1, num_features + 1)):
476-
if size < filter_small_components:
477-
labeled[labeled == feature] = 0
478-
479-
return labeled, largest_component_filtered
480-
481-
482-
def get_components(
483-
table: pd.DataFrame,
484-
labeled: np.typing.NDArray,
485-
scale_factor: float,
486-
distance_component: Optional[int] = 0,
398+
min_component_length: Optional[int] = 50,
399+
min_edge_distance: Optional[float] = 30,
400+
iterations_erode: Optional[int] = None,
487401
) -> List[int]:
488-
"""Indexing coordinates according to labeled array.
489-
490-
Args:
491-
table: Dataframe of segmentation table.
492-
labeled: Array containing differently labeled components.
493-
scale_factor: Scaling for downsampling.
494-
distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included.
495-
496-
Returns:
497-
List of component labels.
498-
"""
499-
unique_labels = list(np.unique(labeled))
500-
501-
# sort non-background labels according to size, descending
502-
unique_labels = [i for i in unique_labels if i != 0]
503-
sizes = [(labeled == i).sum() for i in unique_labels]
504-
sizes, unique_labels = zip(*sorted(zip(sizes, unique_labels), reverse=True))
505-
506-
component_labels = [0 for _ in range(len(table))]
507-
for label_index, l in enumerate(unique_labels):
508-
label_arr = (labeled == l).astype(np.uint8)
509-
centroids_binary = coordinates_in_downscaled_blocks(table, label_arr,
510-
scale_factor, distance_component=distance_component)
511-
for num, c in enumerate(centroids_binary):
512-
if c != 0:
513-
component_labels[num] = label_index + 1
514-
515-
return component_labels
516-
517-
518-
def component_labels_graph(table: pd.DataFrame) -> List[int]:
519402
"""Label components using graph connected components.
520403
521404
Args:
522405
table: Dataframe of segmentation table.
406+
threshold_erode: Threshold of column value after erosion step with spatial statistics.
407+
min_component_length: Minimal number of points a connected component must contain; smaller components are removed.
408+
min_edge_distance: Maximal distance in micrometer between two points for an edge to be created between them.
409+
iterations_erode: Number of iterations for erosion, normally determined automatically.
523410
524411
Returns:
525-
List of component label for each point in dataframe.
412+
List of component labels, one for each point in the dataframe. 0 is background; components are numbered from 1 in descending order of size.
526413
"""
527-
components = erode_sgn_seg_graph(table)
414+
components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
415+
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
528416

529417
length_components = [len(c) for c in components]
530418
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
531419

532420
component_labels = [0 for _ in range(len(table))]
421+
# be aware of 'label_id' of dataframe starting at 1
533422
for lab, comp in enumerate(components):
534423
for comp_index in comp:
535424
component_labels[comp_index - 1] = lab + 1
536425

537426
return component_labels
538427

539428

540-
def component_labels_downscaling(table: pd.DataFrame, scale_factor: float = 20) -> List[int]:
541-
"""Label components using downscaling and connected components.
542-
543-
Args:
544-
table: Dataframe of segmentation table.
545-
scale_factor: Factor for downscaling.
546-
547-
Returns:
548-
List of component label for each point in dataframe.
549-
"""
550-
labeled, largest_component = erode_sgn_seg_downscaling(table, filter_small_components=10,
551-
scale_factor=scale_factor, threshold_erode=None)
552-
component_labels = get_components(table, labeled, scale_factor, distance_component=1)
553-
554-
return component_labels
555-
556-
557429
def postprocess_sgn_seg(
558430
table: pd.DataFrame,
559-
postprocess_type: Optional[str] = "downsampling",
560431
) -> pd.DataFrame:
561432
"""Postprocessing SGN segmentation of cochlea.
562433
563434
Args:
564435
table: Dataframe of segmentation table.
565-
postprocess_type: Postprocessing method, either 'downsampling' or 'graph'.
566436
567437
Returns:
568438
Dataframe with component labels.
569439
"""
570-
if postprocess_type == "downsampling":
571-
component_labels = component_labels_downscaling(table)
572-
elif postprocess_type == "graph":
573-
component_labels = component_labels_graph(table)
440+
component_labels = label_components(table)
574441

575442
table.loc[:, "component_labels"] = component_labels
576443

0 commit comments

Comments
 (0)