240 changes: 239 additions & 1 deletion flamingo_tools/segmentation/postprocessing.py
@@ -1,4 +1,5 @@
import multiprocessing as mp
import os
from concurrent import futures
from typing import Callable, Tuple, Optional

@@ -8,8 +9,11 @@
import pandas as pd

from elf.io import open_file
from scipy.ndimage import binary_fill_holes
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import label
from scipy.sparse import csr_matrix
from scipy.spatial import distance
from scipy.spatial import cKDTree, ConvexHull
from skimage import measure
from sklearn.neighbors import NearestNeighbors
@@ -205,3 +209,237 @@ def filter_chunk(block_id):
    )

    return n_ids, n_ids_filtered


# Postprocess the segmentation by erosion, using the spatial statistics from above.
# Currently implemented by downscaling the centroids and looking for connected components.
# TODO: Change the implementation to graph connected components (see the sketch below).
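
# A minimal sketch of such a graph-based variant, assuming an illustrative
# `radius` parameter (in micrometer) that defines when two centroids are
# connected; neither the helper nor its threshold is part of this module yet:
#
#   from scipy.sparse.csgraph import connected_components
#
#   def graph_components(table: pd.DataFrame, radius: float) -> np.ndarray:
#       coords = table[["anchor_x", "anchor_y", "anchor_z"]].to_numpy()
#       # connect all centroid pairs that are closer than `radius`
#       pairs = cKDTree(coords).query_pairs(r=radius, output_type="ndarray")
#       adjacency = csr_matrix(
#           (np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])),
#           shape=(len(coords), len(coords)),
#       )
#       # label the connected components of the undirected centroid graph
#       _, labels = connected_components(adjacency, directed=False)
#       return labels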


def erode_subset(
    table: pd.DataFrame,
    iterations: int = 1,
    min_cells: Optional[int] = None,
    threshold: float = 35,
    keyword: str = "distance_nn100",
) -> pd.DataFrame:
    """Erode the coordinates of a dataframe according to a keyword and a threshold.

    Pass a copy of the dataframe as input if the original should not be modified.

    Args:
        table: Dataframe of segmentation table.
        iterations: Number of steps for the erosion process.
        min_cells: Minimal number of rows. The erosion stops before falling below this number.
        threshold: Upper threshold for removing elements according to the given keyword.
        keyword: Column of the dataframe used for the erosion.

    Returns:
        The dataframe containing the elements left after the erosion.
    """
    print("initial length", len(table))
    n_neighbors = 100
    for i in range(iterations):
        table = table[table[keyword] < threshold]

        # TODO: support other spatial statistics
        distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)

        if min_cells is not None and len(distance_avg) < min_cells:
            print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
            break

        table.loc[:, f"distance_nn{n_neighbors}"] = list(distance_avg)

        print(f"{i}-th iteration, length of subset {len(table)}")

    return table
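
# Example usage (a sketch; assumes the table already has a "distance_nn100"
# column with the average distance to the 100 nearest neighbors in micrometer):
#
#   eroded = erode_subset(table.copy(), iterations=5, min_cells=20000,
#                         threshold=35, keyword="distance_nn100")
#   print(len(table), "->", len(eroded))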


def downscaled_centroids(
    table: pd.DataFrame,
    scale_factor: int,
    ref_dimensions: Optional[Tuple[float, float, float]] = None,
    capped: bool = True,
) -> np.typing.NDArray:
    """Downscale the centroids of a dataframe.

    Args:
        table: Dataframe of segmentation table.
        scale_factor: Factor for downscaling the coordinates.
        ref_dimensions: Reference dimensions for the downscaling. Taken from the centroids if not supplied.
        capped: Flag for capping the output array at 1 to create a binary mask.

    Returns:
        The downscaled array.
    """
    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
    centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]

    if ref_dimensions is None:
        ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
    bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in ref_dimensions)
    new_array = np.zeros(bounding_dimensions_scaled)

    for c in centroids_scaled:
        new_array[int(c[0]), int(c[1]), int(c[2])] += 1

    array_downscaled = np.round(new_array).astype(int)

    if capped:
        array_downscaled[array_downscaled >= 1] = 1

    return array_downscaled
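
# For example, with scale_factor=20 a centroid at (x=105.0, y=42.0, z=7.5) is
# counted in block (5, 2, 0); with capped=True the result is a binary
# occupancy mask of the centroids:
#
#   mask = downscaled_centroids(table, scale_factor=20)
#   print(mask.shape, mask.max())  # the maximum is 1 for a binary mask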


def coordinates_in_downscaled_blocks(
    table: pd.DataFrame,
    down_array: np.typing.NDArray,
    scale_factor: float,
    distance_component: int = 0,
) -> list:
    """Check which coordinates lie within the downscaled array.

    Args:
        table: Dataframe of segmentation table.
        down_array: Downscaled array.
        scale_factor: Factor which was used for the downscaling.
        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are included.

    Returns:
        A binary list indicating which of the dataframe coordinates lie within the array.
    """
    # fill holes in the down-sampled array
    down_array[down_array > 0] = 1
    down_array = binary_fill_holes(down_array).astype(np.uint8)

    # check if the input coordinates are within the down-sampled blocks
    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
    centroids_scaled = [
        np.floor(np.array([c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor])) for c in centroids
    ]

    # distance of each background block to the nearest foreground block
    distance_map = distance_transform_edt(down_array == 0)

    centroids_binary = []
    for c in centroids_scaled:
        coord = (int(c[0]), int(c[1]), int(c[2]))
        if down_array[coord] != 0:
            centroids_binary.append(1)
        elif distance_map[coord] <= distance_component:
            centroids_binary.append(1)
        else:
            centroids_binary.append(0)

    return centroids_binary
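
# Example usage (a sketch): with distance_component=1, centroids that fall into
# a background block face-adjacent to a foreground block still count as inside.
#
#   mask = downscaled_centroids(table, scale_factor=20)
#   inside = coordinates_in_downscaled_blocks(table, mask, 20, distance_component=1)
#   table_inside = table[np.array(inside) == 1]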


def erode_sgn_seg(
    table: pd.DataFrame,
    keyword: str = "distance_nn100",
    filter_small_components: Optional[int] = None,
    scale_factor: float = 20,
    threshold_erode: Optional[float] = None,
) -> Tuple[np.typing.NDArray, np.typing.NDArray]:
    """Erode the SGN segmentation.

    Args:
        table: Dataframe of segmentation table.
        keyword: Column of the dataframe used for the erosion.
        filter_small_components: Filter out components smaller than this number of blocks after labeling.
        scale_factor: Scaling factor for the downsampling.
        threshold_erode: Threshold for the column value in the erosion step with spatial statistics.

    Returns:
        The labeled components of the downscaled, eroded coordinates.
        The largest connected component of the labeled components.
    """
    ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
    print("initial length", len(table))
    distance_nn = list(table[keyword])
    distance_nn.sort()

    if len(table) < 20000:
        iterations = 1
        min_cells = None
        # use the 80th percentile of the sorted distances as the default threshold
        average_dist = int(distance_nn[int(len(table) * 0.8)])
        threshold = threshold_erode if threshold_erode is not None else average_dist
    else:
        iterations = 15
        min_cells = 20000
        threshold = threshold_erode if threshold_erode is not None else 40

    print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")

    new_subset = erode_subset(table.copy(), iterations=iterations,
                              threshold=threshold, min_cells=min_cells, keyword=keyword)
    eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions)

    # label the connected components
    labeled, num_features = label(eroded_arr)

    # find the largest component
    sizes = [(labeled == i).sum() for i in range(1, num_features + 1)]
    largest_label = np.argmax(sizes) + 1

    # extract only the largest component and fill holes in it
    largest_component = (labeled == largest_label).astype(np.uint8)
    largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8)

    # filter out small components
    if filter_small_components is not None:
        for size, feature in zip(sizes, range(1, num_features + 1)):
            if size < filter_small_components:
                labeled[labeled == feature] = 0

    return labeled, largest_component_filtered
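
# Example usage (a sketch; values follow the defaults chosen above):
#
#   labeled, largest = erode_sgn_seg(table, filter_small_components=10, scale_factor=20)
#   print("highest component label:", labeled.max())
#   print("size of largest component:", largest.sum())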


def get_components(
    table: pd.DataFrame,
    labeled: np.typing.NDArray,
    scale_factor: float,
    distance_component: int = 0,
) -> list:
    """Assign component labels to the coordinates according to a labeled array.

    Args:
        table: Dataframe of segmentation table.
        labeled: Array containing the differently labeled components.
        scale_factor: Scaling factor for the downsampling.
        distance_component: Distance in downscaled units up to which centroids next to downscaled blocks are included.

    Returns:
        List of component labels.
    """
    unique_labels = list(np.unique(labeled))
    component_labels = [0 for _ in range(len(table))]
    for label_index, unique_label in enumerate(unique_labels):
        if unique_label != 0:
            label_arr = (labeled == unique_label).astype(np.uint8)
            centroids_binary = coordinates_in_downscaled_blocks(
                table, label_arr, scale_factor, distance_component=distance_component
            )
            for num, c in enumerate(centroids_binary):
                if c != 0:
                    component_labels[num] = label_index
    return component_labels
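
# Example usage (a sketch): assign each row the label of the component whose
# downscaled blocks contain (or nearly contain) its centroid.
#
#   component_labels = get_components(table, labeled, scale_factor=20, distance_component=1)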


def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: float = 20) -> pd.DataFrame:
    """Postprocess the SGN segmentation of the cochlea.

    Args:
        table: Dataframe of segmentation table.
        scale_factor: Scaling factor for the downsampling.

    Returns:
        The dataframe with component labels.
    """
    labeled, largest_component = erode_sgn_seg(table, filter_small_components=10,
                                               scale_factor=scale_factor, threshold_erode=None)

    component_labels = get_components(table, labeled, scale_factor, distance_component=1)

    table.loc[:, "component_labels"] = component_labels

    return table
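
# End-to-end example (a sketch; the path is illustrative and the table is
# assumed to already contain the "distance_nn100" column):
#
#   table = pd.read_csv("SGN_segmentation_table.tsv", sep="\t")
#   table = postprocess_sgn_seg(table, scale_factor=20)
#   print(table["component_labels"].value_counts())
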
2 changes: 1 addition & 1 deletion scripts/prediction/expand_seg_table.py
@@ -61,7 +61,7 @@ def main(
     neighbor_counts = [n[0] for n in neighbor_counts]
     tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts

-    tsv_table.to_csv(out_path, sep="\t")
+    tsv_table.to_csv(out_path, sep="\t", index=False)


if __name__ == "__main__":