3 changes: 2 additions & 1 deletion environment.yaml
@@ -13,4 +13,5 @@ dependencies:
- s3fs
- torch_em
- z5py
- zarr
  # Don't install zarr v3, as we are not yet sure that it is compatible with MoBIE etc.
- zarr <3
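Note: the `<3` pin keeps the solver on the zarr v2 series. A quick sanity check after building the environment (a minimal sketch, assuming the environment is activated):

python -c "import zarr; print(zarr.__version__)"  # expected to print a 2.x version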
7 changes: 6 additions & 1 deletion flamingo_tools/file_utils.py
@@ -7,6 +7,11 @@
import zarr
from elf.io import open_file

try:
from zarr.abc.store import Store
except ImportError:
from zarr._storage.store import BaseStore as Store


def _parse_shape(metadata_file):
depth, height, width = None, None, None
@@ -62,7 +67,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]:
return x


def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike:
def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike:
"""Read flamingo image data, stored in various formats.
Args:
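Note: the try/except above resolves a single `Store` alias that works on both zarr v2 (base class `zarr._storage.store.BaseStore`) and zarr v3 (`zarr.abc.store.Store`), so the annotation on `read_image_data` no longer pins the zarr major version. A minimal sketch of the pattern in isolation (the `is_store` helper is hypothetical, not part of this PR):

try:
    from zarr.abc.store import Store  # zarr v3 location
except ImportError:
    from zarr._storage.store import BaseStore as Store  # zarr v2 location

def is_store(obj) -> bool:
    # Distinguish an actual zarr store from a plain file path.
    return isinstance(obj, Store)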
7 changes: 6 additions & 1 deletion flamingo_tools/s3_utils.py
@@ -7,6 +7,11 @@
import s3fs
import zarr

try:
from zarr.abc.store import Store
except ImportError:
from zarr._storage.store import BaseStore as Store


# Dedicated bucket for cochlea lightsheet project
MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
@@ -93,7 +98,7 @@ def get_s3_path(
bucket_name: Optional[str] = None,
service_endpoint: Optional[str] = None,
credential_file: Optional[str] = None,
) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]:
) -> Tuple[Store, s3fs.core.S3FileSystem]:
"""Get S3 path for a file or folder and file system based on S3 parameters and credentials.
Args:
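Note: a store returned by `get_s3_path` can be opened like any local zarr store. A minimal usage sketch under the zarr v2 pin from environment.yaml (endpoint, bucket, and dataset layout are hypothetical):

import s3fs
import zarr

fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": "https://s3.example.org"})
store = s3fs.S3Map(root="example-bucket/example.ome.zarr", s3=fs)
data = zarr.open(store, mode="r")  # zarr v2 accepts any MutableMapping as a store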
264 changes: 262 additions & 2 deletions flamingo_tools/segmentation/postprocessing.py
@@ -1,15 +1,17 @@
import math
import multiprocessing as mp
from concurrent import futures
from typing import Callable, Tuple, Optional
from typing import Callable, List, Optional, Tuple

import elf.parallel as parallel
import numpy as np
import nifty.tools as nt
import networkx as nx
import pandas as pd

from elf.io import open_file
from scipy.spatial import distance
from scipy.sparse import csr_matrix
from scipy.spatial import distance
from scipy.spatial import cKDTree, ConvexHull
from skimage import measure
from sklearn.neighbors import NearestNeighbors
@@ -205,3 +207,261 @@ def filter_chunk(block_id):
)

return n_ids, n_ids_filtered


def erode_subset(
table: pd.DataFrame,
iterations: int = 1,
min_cells: Optional[int] = None,
threshold: int = 35,
keyword: str = "distance_nn100",
) -> pd.DataFrame:
"""Erode coordinates of dataframe according to a keyword and a threshold.
Use a copy of the dataframe as an input, if it should not be edited.

Args:
table: Dataframe of segmentation table.
iterations: Number of steps for erosion process.
        min_cells: Minimal number of rows; the erosion stops once the table falls below this limit.
        threshold: Rows whose keyword value is at or above this threshold are removed.
        keyword: Column of the dataframe used for erosion.

Returns:
The dataframe containing elements left after the erosion.
"""
print("initial length", len(table))
n_neighbors = 100
for i in range(iterations):
table = table[table[keyword] < threshold]

distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors)

if min_cells is not None and len(distance_avg) < min_cells:
print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion")
break

table.loc[:, 'distance_nn'+str(n_neighbors)] = list(distance_avg)

print(f"{i}-th iteration, length of subset {len(table)}")

return table
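A usage sketch for `erode_subset` (the input table is assumed to already carry the default "distance_nn100" column):

# Erode a copy so the original table stays intact; stop early below 20000 rows.
eroded = erode_subset(table.copy(), iterations=10, min_cells=20000, threshold=35)
print(len(eroded), "cells remain after erosion")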


def downscaled_centroids(
table: pd.DataFrame,
scale_factor: int,
ref_dimensions: Optional[Tuple[float, float, float]] = None,
downsample_mode: str = "accumulated",
) -> np.typing.NDArray:
"""Downscale centroids in dataframe.

Args:
table: Dataframe of segmentation table.
scale_factor: Factor for downscaling coordinates.
ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied.
downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'.

Returns:
        The downscaled array.
"""
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]

    if ref_dimensions is None:
        ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
    bounding_dimensions_scaled = tuple(round(b // scale_factor + 1) for b in ref_dimensions)
    new_array = np.zeros(bounding_dimensions_scaled)

if downsample_mode == "accumulated":
for c in centroids_scaled:
new_array[int(c[0]), int(c[1]), int(c[2])] += 1

elif downsample_mode == "capped":
for c in centroids_scaled:
new_array[int(c[0]), int(c[1]), int(c[2])] = 1

elif downsample_mode == "components":
if "component_labels" not in table.columns:
raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.")
component_labels = list(table["component_labels"])
for comp, centr in zip(component_labels, centroids_scaled):
if comp != 0:
new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp

else:
raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.")

new_array = np.round(new_array).astype(int)

return new_array
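For example, with `scale_factor=10` a centroid at (123.4, 56.7, 89.0) lands in voxel (12, 5, 8) of the downscaled volume. A usage sketch:

# Accumulate a low-resolution density map of all centroids.
heatmap = downscaled_centroids(table, scale_factor=10, downsample_mode="accumulated")
print(heatmap.shape, heatmap.max())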


def components_sgn(
table: pd.DataFrame,
keyword: str = "distance_nn100",
threshold_erode: Optional[float] = None,
postprocess_graph: bool = False,
min_component_length: int = 50,
min_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
) -> List[List[int]]:
"""Eroding the SGN segmentation.

Args:
table: Dataframe of segmentation table.
keyword: Keyword of the dataframe column for erosion.
        threshold_erode: Threshold for the erosion column value; derived from the nearest-neighbor distances if not given.
        postprocess_graph: Post-process the connected components by adding nearby points from the original table.
        min_component_length: Minimal number of nodes for keeping a connected component.
        min_edge_distance: Maximal distance in micrometer between points for creating graph edges.
        iterations_erode: Number of erosion iterations; determined automatically if not given.

Returns:
        Connected components as lists of 'label_id' values of the dataframe.
"""
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
labels = [int(i) for i in list(table["label_id"])]

distance_nn = list(table[keyword])
distance_nn.sort()

if len(table) < 20000:
iterations = iterations_erode if iterations_erode is not None else 0
min_cells = None
        average_dist = int(distance_nn[int(len(table) * 0.8)])  # 80th percentile of nearest-neighbor distances
threshold = threshold_erode if threshold_erode is not None else average_dist
else:
iterations = iterations_erode if iterations_erode is not None else 15
min_cells = 20000
threshold = threshold_erode if threshold_erode is not None else 40

print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")

new_subset = erode_subset(table.copy(), iterations=iterations,
threshold=threshold, min_cells=min_cells, keyword=keyword)

# create graph from coordinates of eroded subset
centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
labels_subset = [int(i) for i in list(new_subset["label_id"])]
coords = {}
for index, element in zip(labels_subset, centroids_subset):
coords[index] = element

graph = nx.Graph()
for num, pos in coords.items():
graph.add_node(num, pos=pos)

    # create edges between points whose distance is at most min_edge_distance
for i in coords:
for j in coords:
if i < j:
dist = math.dist(coords[i], coords[j])
if dist <= min_edge_distance:
graph.add_edge(i, j, weight=dist)

components = list(nx.connected_components(graph))

    # remove connected components with fewer nodes than min_component_length
for component in components:
if len(component) < min_component_length:
for c in component:
graph.remove_node(c)

components = [list(s) for s in nx.connected_components(graph)]

    # add original coordinates that lie closer to an eroded component than the threshold
if postprocess_graph:
threshold = 15
for label_id, centr in zip(labels, centroids):
if label_id not in labels_subset:
add_coord = []
for comp_index, component in enumerate(components):
for comp_label in component:
                        dist = math.dist(centr, centroids[comp_label - 1])  # label_ids are assumed contiguous from 1
if dist <= threshold:
add_coord.append([comp_index, label_id])
break
if len(add_coord) != 0:
components[add_coord[0][0]].append(add_coord[0][1])

return components
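Note: the nested loop above builds edges with O(n^2) pairwise distance checks over the eroded centroids. Since `cKDTree` is already imported in this module, the same edge set could be built in roughly O(n log n); a sketch of the equivalent construction (an alternative, not part of this PR):

labels_list = list(coords.keys())
tree = cKDTree([coords[k] for k in labels_list])
for a, b in tree.query_pairs(r=min_edge_distance):
    i, j = labels_list[a], labels_list[b]
    graph.add_edge(i, j, weight=math.dist(coords[i], coords[j]))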


def label_components(
table: pd.DataFrame,
min_size: int = 1000,
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
) -> List[int]:
"""Label components using graph connected components.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
        threshold_erode: Threshold for the erosion column value; derived from the nearest-neighbor distances if not given.
        min_component_length: Minimal number of nodes for keeping a connected component.
        min_edge_distance: Maximal distance in micrometer between points for creating graph edges.
        iterations_erode: Number of erosion iterations; determined automatically if not given.

Returns:
        List of component labels, one for each point in the dataframe.
        0 marks background; components are numbered from 1 in descending order of size.
"""

# First, apply the size filter.
entries_filtered = table[table.n_pixels < min_size]
table = table[table.n_pixels >= min_size]

components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)

# add size-filtered objects to have same initial length
table = pd.concat([table, entries_filtered], ignore_index=True)
    table = table.sort_values("label_id")

length_components = [len(c) for c in components]
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))

component_labels = [0 for _ in range(len(table))]
    # note that 'label_id' of the dataframe starts at 1
for lab, comp in enumerate(components):
for comp_index in comp:
component_labels[comp_index - 1] = lab + 1

return component_labels


def postprocess_sgn_seg(
table: pd.DataFrame,
min_size: int = 1000,
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
) -> pd.DataFrame:
"""Postprocessing SGN segmentation of cochlea.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
        threshold_erode: Threshold for the erosion column value; derived from the nearest-neighbor distances if not given.
        min_component_length: Minimal number of nodes for keeping a connected component.
        min_edge_distance: Maximal distance in micrometer between points for creating graph edges.
        iterations_erode: Number of erosion iterations; determined automatically if not given.

Returns:
Dataframe with component labels.
"""

comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)

table.loc[:, "component_labels"] = comp_labels

return table
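Taken together, a minimal end-to-end sketch of the new post-processing entry point (the file paths are hypothetical):

import pandas as pd
from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg

table = pd.read_csv("segmentation_table.tsv", sep="\t")
table = postprocess_sgn_seg(table, min_size=1000, min_component_length=50)
table.to_csv("segmentation_table_with_components.tsv", sep="\t", index=False)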
2 changes: 1 addition & 1 deletion scripts/prediction/expand_seg_table.py
@@ -61,7 +61,7 @@ def main(
neighbor_counts = [n[0] for n in neighbor_counts]
tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts

tsv_table.to_csv(out_path, sep="\t")
tsv_table.to_csv(out_path, sep="\t", index=False)


if __name__ == "__main__":
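Note: without `index=False`, `DataFrame.to_csv` writes the row index as an extra unnamed first column, which shifts all columns when the TSV is read back. A quick check (a sketch; `out_path` as defined above):

cols = pd.read_csv(out_path, sep="\t").columns
assert "Unnamed: 0" not in cols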