From 1ceafb85610ebbe40d6ae83893f04c5d9257ebf6 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 30 Apr 2025 17:58:17 +0200 Subject: [PATCH 1/8] Postprocessing cochlea segmentation using erosion --- flamingo_tools/segmentation/postprocessing.py | 240 +++++++++++++++++- scripts/prediction/expand_seg_table.py | 2 +- 2 files changed, 240 insertions(+), 2 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 7ad987b..2f32dab 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import os from concurrent import futures from typing import Callable, Tuple, Optional @@ -8,8 +9,11 @@ import pandas as pd from elf.io import open_file -from scipy.spatial import distance +from scipy.ndimage import binary_fill_holes +from scipy.ndimage import distance_transform_edt +from scipy.ndimage import label from scipy.sparse import csr_matrix +from scipy.spatial import distance from scipy.spatial import cKDTree, ConvexHull from skimage import measure from sklearn.neighbors import NearestNeighbors @@ -205,3 +209,237 @@ def filter_chunk(block_id): ) return n_ids, n_ids_filtered + + +# Postprocess segmentation by erosion using the above spatial statistics. +# Currently implemented using downscaling and looking for connected components +# TODO: Change implementation to graph connected components. + + +def erode_subset( + table: pd.DataFrame, + iterations: Optional[int] = 1, + min_cells: Optional[int] = None, + threshold: Optional[int] = 35, + keyword: Optional[str] = "distance_nn100", +) -> pd.DataFrame: + """Erode coordinates of dataframe according to a keyword and a threshold. + Use a copy of the dataframe as an input, if it should not be edited. + + Args: + table: Dataframe of segmentation table. + iterations: Number of steps for erosion process. + min_cells: Minimal number of rows. The erosion is stopped before reaching this number. + threshold: Upper threshold for removing elements according to the given keyword. + keyword: Keyword of dataframe for erosion. + + Returns: + The dataframe containing elements left after the erosion. + """ + print("initial length", len(table)) + n_neighbors = 100 + for i in range(iterations): + table = table[table[keyword] < threshold] + + # TODO: support other spatial statistics + distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors) + + if min_cells is not None and len(distance_avg) < min_cells: + print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion") + break + + table.loc[:, 'distance_nn'+str(n_neighbors)] = list(distance_avg) + + print(f"{i}-th iteration, length of subset {len(table)}") + + return table + + +def downscaled_centroids( + table: pd.DataFrame, + scale_factor: int, + ref_dimensions: Optional[Tuple[float,float,float]] = None, + capped: Optional[bool] = True, +) -> np.typing.NDArray: + """Downscale centroids in dataframe. + + Args: + table: Dataframe of segmentation table. + scale_factor: Factor for downscaling coordinates. + ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. + capped: Flag for capping output of array at 1 for the creation of a binary mask. 
+ + Returns: + The downscaled array + """ + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids] + + if ref_dimensions is None: + bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in bounding_dimensions]) + new_array = np.zeros(bounding_dimensions_scaled) + + else: + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions]) + new_array = np.zeros(bounding_dimensions_scaled) + + for c in centroids_scaled: + new_array[int(c[0]), int(c[1]), int(c[2])] += 1 + + array_downscaled = np.round(new_array).astype(int) + + if capped: + array_downscaled[array_downscaled >= 1] = 1 + + return array_downscaled + + +def coordinates_in_downscaled_blocks( + table: pd.DataFrame, + down_array: np.typing.NDArray, + scale_factor: float, + distance_component: Optional[int] = 0, +) -> list: + """Checking if coordinates are within the downscaled array. + + Args: + table: Dataframe of segmentation table. + down_array: Downscaled array. + scale_factor: Factor which was used for downscaling. + distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included. + + Returns: + A binary list representing whether the dataframe coordinates are within the array. + """ + # fill holes in down-sampled array + down_array[down_array > 0] = 1 + down_array = binary_fill_holes(down_array).astype(np.uint8) + + # check if input coordinates are within down-sampled blocks + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + centroids_scaled = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids] + + distance_map = distance_transform_edt(down_array == 0) + + centroids_binary = [] + for c in centroids_scaled: + coord = (int(c[0]), int(c[1]), int(c[2])) + if down_array[coord] != 0: + centroids_binary.append(1) + elif distance_map[coord] <= distance_component: + centroids_binary.append(1) + else: + centroids_binary.append(0) + + return centroids_binary + + +def erode_sgn_seg( + table: pd.DataFrame, + keyword: Optional[str] = "distance_nn100", + filter_small_components: Optional[int] = None, + scale_factor: Optional[float] = 20, + threshold_erode: Optional[float] = None, +) -> Tuple[pd.DataFrame,np.typing.NDArray,np.typing.NDArray,np.typing.NDArray]: + """Eroding the SGN segmentation. + + Args: + table: Dataframe of segmentation table. + keyword: Keyword of the dataframe column for erosion. + filter_small_components: Filter components smaller after n blocks after labeling. + scale_factor: Scaling for downsampling. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + + Returns: + The labeled components of the downscaled, eroded coordinates. + The larget connected component of the labeled components. 
+ """ + + ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) + print("initial length", len(table)) + distance_nn = list(table[keyword]) + distance_nn.sort() + + if len(table) < 20000: + iterations = 1 + min_cells = None + average_dist = int(distance_nn[int(len(table) * 0.8)]) + threshold = threshold_erode if threshold_erode is not None else average_dist + else: + iterations = 15 + min_cells = 20000 + threshold = threshold_erode if threshold_erode is not None else 40 + + print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") + + new_subset = erode_subset(table.copy(), iterations=iterations, + threshold=threshold, min_cells=min_cells, keyword=keyword) + eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions) + # Label connected components + labeled, num_features = label(eroded_arr) + + # Find the largest component + sizes = [(labeled == i).sum() for i in range(1, num_features + 1)] + largest_label = np.argmax(sizes) + 1 + + # Extract only the largest component + largest_component = (labeled == largest_label).astype(np.uint8) + largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8) + + #filter small sizes + if filter_small_components is not None: + for (size, feature) in zip(sizes, range(1, num_features + 1)): + if size < filter_small_components: + labeled[labeled == feature] = 0 + + return labeled, largest_component_filtered + + +def get_components(table: pd.DataFrame, + labeled: np.typing.NDArray, + scale_factor: float, + distance_component: Optional[int] = 0, +) -> list: + """Indexing coordinates according to labeled array. + + Args: + table: Dataframe of segmentation table. + labeled: Array containing differently labeled components. + scale_factor: Scaling for downsampling. + distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included. + + Returns: + List of component labels. + """ + unique_labels = list(np.unique(labeled)) + component_labels = [0 for _ in range(len(table))] + for label_index, l in enumerate(unique_labels): + if l != 0: + label_arr = (labeled == l).astype(np.uint8) + centroids_binary = coordinates_in_downscaled_blocks(table, label_arr, + scale_factor, distance_component = distance_component) + for num, c in enumerate(centroids_binary): + if c != 0: + component_labels[num] = label_index + return component_labels + + +def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: Optional[float] = 20) -> pd.DataFrame: + """Postprocessing SGN segmentation of cochlea. + + Args: + table: Dataframe of segmentation table. + scale_factor: Scaling for downsampling. + + Returns: + Dataframe with component labels. 
+ """ + labeled, largest_component = erode_sgn_seg(table, filter_small_labels=10, + scale_factor=scale_factor, threshold_erode=None) + + component_labels = get_components(table, labeled, scale_factor, distance_component = 1) + + table.loc[:, "component_labels"] = component_labels + + return table \ No newline at end of file diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py index 1b6d3f7..b8a71ff 100644 --- a/scripts/prediction/expand_seg_table.py +++ b/scripts/prediction/expand_seg_table.py @@ -61,7 +61,7 @@ def main( neighbor_counts = [n[0] for n in neighbor_counts] tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts - tsv_table.to_csv(out_path, sep="\t") + tsv_table.to_csv(out_path, sep="\t", index=False) if __name__ == "__main__": From 2fa29b154a099b7bf27e6f20214103c833e9158f Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 7 May 2025 10:42:10 +0200 Subject: [PATCH 2/8] Sort components according to their size --- flamingo_tools/segmentation/postprocessing.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 2f32dab..56126d2 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -413,15 +413,21 @@ def get_components(table: pd.DataFrame, List of component labels. """ unique_labels = list(np.unique(labeled)) + + # sort non-background labels according to size, descending + unique_labels = [i for i in unique_labels if i != 0] + sizes = [(labeled == i).sum() for i in unique_labels] + sizes, unique_labels = zip(*sorted(zip(sizes, unique_labels), reverse=True)) + component_labels = [0 for _ in range(len(table))] for label_index, l in enumerate(unique_labels): - if l != 0: - label_arr = (labeled == l).astype(np.uint8) - centroids_binary = coordinates_in_downscaled_blocks(table, label_arr, - scale_factor, distance_component = distance_component) - for num, c in enumerate(centroids_binary): - if c != 0: - component_labels[num] = label_index + label_arr = (labeled == l).astype(np.uint8) + centroids_binary = coordinates_in_downscaled_blocks(table, label_arr, + scale_factor, distance_component = distance_component) + for num, c in enumerate(centroids_binary): + if c != 0: + component_labels[num] = label_index + 1 + return component_labels From c6e2b04341289dc98d85737b2476b3c3441c4af8 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 7 May 2025 12:00:31 +0200 Subject: [PATCH 3/8] Added graph connected components for postprocessing --- flamingo_tools/segmentation/postprocessing.py | 181 +++++++++++++++--- 1 file changed, 153 insertions(+), 28 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 56126d2..763dc82 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -1,11 +1,12 @@ +import math import multiprocessing as mp -import os from concurrent import futures -from typing import Callable, Tuple, Optional +from typing import Callable, List, Optional, Tuple import elf.parallel as parallel import numpy as np import nifty.tools as nt +import networkx as nx import pandas as pd from elf.io import open_file @@ -258,8 +259,8 @@ def erode_subset( def downscaled_centroids( table: pd.DataFrame, scale_factor: int, - ref_dimensions: Optional[Tuple[float,float,float]] = None, - capped: Optional[bool] = True, + 
ref_dimensions: Optional[Tuple[float, float, float]] = None, + downsample_mode: Optional[str] = "accumulated", ) -> np.typing.NDArray: """Downscale centroids in dataframe. @@ -267,7 +268,7 @@ def downscaled_centroids( table: Dataframe of segmentation table. scale_factor: Factor for downscaling coordinates. ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. - capped: Flag for capping output of array at 1 for the creation of a binary mask. + downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components' Returns: The downscaled array @@ -284,15 +285,27 @@ def downscaled_centroids( bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions]) new_array = np.zeros(bounding_dimensions_scaled) - for c in centroids_scaled: - new_array[int(c[0]), int(c[1]), int(c[2])] += 1 + if downsample_mode == "accumulated": + for c in centroids_scaled: + new_array[int(c[0]), int(c[1]), int(c[2])] += 1 - array_downscaled = np.round(new_array).astype(int) + elif downsample_mode == "capped": + new_array = np.round(new_array).astype(int) + new_array[new_array >= 1] = 1 - if capped: - array_downscaled[array_downscaled >= 1] = 1 + elif downsample_mode == "components": + if "component_labels" not in table.columns: + raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.") + component_labels = list(table["component_labels"]) + for comp, centr in zip(component_labels, centroids_scaled): + if comp != 0: + new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp + new_array = np.round(new_array).astype(int) - return array_downscaled + else: + raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.") + + return new_array def coordinates_in_downscaled_blocks( @@ -300,7 +313,7 @@ def coordinates_in_downscaled_blocks( down_array: np.typing.NDArray, scale_factor: float, distance_component: Optional[int] = 0, -) -> list: +) -> List[int]: """Checking if coordinates are within the downscaled array. Args: @@ -318,12 +331,12 @@ def coordinates_in_downscaled_blocks( # check if input coordinates are within down-sampled blocks centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) - centroids_scaled = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids] + centroids = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids] distance_map = distance_transform_edt(down_array == 0) centroids_binary = [] - for c in centroids_scaled: + for c in centroids: coord = (int(c[0]), int(c[1]), int(c[2])) if down_array[coord] != 0: centroids_binary.append(1) @@ -335,13 +348,81 @@ def coordinates_in_downscaled_blocks( return centroids_binary -def erode_sgn_seg( +def erode_sgn_seg_graph( + table: pd.DataFrame, + keyword: Optional[str] = "distance_nn100", + threshold_erode: Optional[float] = None, +) -> List[List[int]]: + """Eroding the SGN segmentation. + + Args: + table: Dataframe of segmentation table. + keyword: Keyword of the dataframe column for erosion. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + + Returns: + Subgraph components as lists of label_ids of dataframe. 
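+
+    Note (editorial sketch, not part of the implementation): the pairwise edge construction
+    below scales quadratically with the number of eroded centroids. An equivalent radius
+    query with the already imported cKDTree would look roughly like this, with the returned
+    index pairs still to be mapped back to label_ids before adding edges:
+
+        tree = cKDTree(list(coords.values()))
+        pairs = tree.query_pairs(r=30)  # index pairs closer than 30 micrometer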
+ """ + print("initial length", len(table)) + distance_nn = list(table[keyword]) + distance_nn.sort() + + if len(table) < 20000: + iterations = 1 + min_cells = None + average_dist = int(distance_nn[int(len(table) * 0.8)]) + threshold = threshold_erode if threshold_erode is not None else average_dist + else: + iterations = 15 + min_cells = 20000 + threshold = threshold_erode if threshold_erode is not None else 40 + + print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") + + new_subset = erode_subset(table.copy(), iterations=iterations, + threshold=threshold, min_cells=min_cells, keyword=keyword) + + # create graph from coordinates of eroded subset + centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) + labels_subset = [int(i) for i in list(new_subset["label_id"])] + coords = {} + for index, element in zip(labels_subset, centroids_subset): + coords[index] = element + + graph = nx.Graph() + for num, pos in coords.items(): + graph.add_node(num, pos=pos) + + # create edges between points whose distance is less than threshold + threshold = 30 + for i in coords: + for j in coords: + if i < j: + dist = math.dist(coords[i], coords[j]) + if dist <= threshold: + graph.add_edge(i, j, weight=dist) + + components = list(nx.connected_components(graph)) + + # remove connected components with less nodes than threshold + min_length = 100 + for component in components: + if len(component) < min_length: + for c in component: + graph.remove_node(c) + + components = list(nx.connected_components(graph)) + + return components + + +def erode_sgn_seg_downscaling( table: pd.DataFrame, keyword: Optional[str] = "distance_nn100", filter_small_components: Optional[int] = None, scale_factor: Optional[float] = 20, threshold_erode: Optional[float] = None, -) -> Tuple[pd.DataFrame,np.typing.NDArray,np.typing.NDArray,np.typing.NDArray]: +) -> Tuple[np.typing.NDArray, np.typing.NDArray]: """Eroding the SGN segmentation. Args: @@ -355,7 +436,6 @@ def erode_sgn_seg( The labeled components of the downscaled, eroded coordinates. The larget connected component of the labeled components. """ - ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) print("initial length", len(table)) distance_nn = list(table[keyword]) @@ -375,7 +455,9 @@ def erode_sgn_seg( new_subset = erode_subset(table.copy(), iterations=iterations, threshold=threshold, min_cells=min_cells, keyword=keyword) + eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions) + # Label connected components labeled, num_features = label(eroded_arr) @@ -387,7 +469,7 @@ def erode_sgn_seg( largest_component = (labeled == largest_label).astype(np.uint8) largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8) - #filter small sizes + # filter small sizes if filter_small_components is not None: for (size, feature) in zip(sizes, range(1, num_features + 1)): if size < filter_small_components: @@ -396,11 +478,12 @@ def erode_sgn_seg( return labeled, largest_component_filtered -def get_components(table: pd.DataFrame, +def get_components( + table: pd.DataFrame, labeled: np.typing.NDArray, scale_factor: float, distance_component: Optional[int] = 0, -) -> list: +) -> List[int]: """Indexing coordinates according to labeled array. 
Args: @@ -423,7 +506,7 @@ def get_components(table: pd.DataFrame, for label_index, l in enumerate(unique_labels): label_arr = (labeled == l).astype(np.uint8) centroids_binary = coordinates_in_downscaled_blocks(table, label_arr, - scale_factor, distance_component = distance_component) + scale_factor, distance_component=distance_component) for num, c in enumerate(centroids_binary): if c != 0: component_labels[num] = label_index + 1 @@ -431,21 +514,63 @@ def get_components(table: pd.DataFrame, return component_labels -def postprocess_sgn_seg(table: pd.DataFrame, scale_factor: Optional[float] = 20) -> pd.DataFrame: +def component_labels_graph(table: pd.DataFrame) -> List[int]: + """Label components using graph connected components. + + Args: + table: Dataframe of segmentation table. + + Returns: + List of component label for each point in dataframe. + """ + components = erode_sgn_seg_graph(table) + + length_components = [len(c) for c in components] + length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) + + component_labels = [0 for _ in range(len(table))] + for lab, comp in enumerate(components): + for comp_index in comp: + component_labels[comp_index] = lab + 1 + + return component_labels + + +def component_labels_downscaling(table: pd.DataFrame, scale_factor: float = 20) -> List[int]: + """Label components using downscaling and connected components. + + Args: + table: Dataframe of segmentation table. + scale_factor: Factor for downscaling. + + Returns: + List of component label for each point in dataframe. + """ + labeled, largest_component = erode_sgn_seg_downscaling(table, filter_small_components=10, + scale_factor=scale_factor, threshold_erode=None) + component_labels = get_components(table, labeled, scale_factor, distance_component=1) + + return component_labels + + +def postprocess_sgn_seg( + table: pd.DataFrame, + postprocess_type: Optional[str] = "downsampling", +) -> pd.DataFrame: """Postprocessing SGN segmentation of cochlea. Args: table: Dataframe of segmentation table. - scale_factor: Scaling for downsampling. + postprocess_type: Postprocessing method, either 'downsampling' or 'graph'. Returns: Dataframe with component labels. 
""" - labeled, largest_component = erode_sgn_seg(table, filter_small_labels=10, - scale_factor=scale_factor, threshold_erode=None) - - component_labels = get_components(table, labeled, scale_factor, distance_component = 1) + if postprocess_type == "downsampling": + component_labels = component_labels_downscaling(table) + elif postprocess_type == "graph": + component_labels = component_labels_graph(table) table.loc[:, "component_labels"] = component_labels - return table \ No newline at end of file + return table From c51947d9cc1c4c52632f91dc09ca698d7809f460 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 7 May 2025 16:51:16 +0200 Subject: [PATCH 4/8] Fixed indexing of components --- flamingo_tools/segmentation/postprocessing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 763dc82..6b09f9d 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -290,8 +290,8 @@ def downscaled_centroids( new_array[int(c[0]), int(c[1]), int(c[2])] += 1 elif downsample_mode == "capped": - new_array = np.round(new_array).astype(int) - new_array[new_array >= 1] = 1 + for c in centroids_scaled: + new_array[int(c[0]), int(c[1]), int(c[2])] = 1 elif downsample_mode == "components": if "component_labels" not in table.columns: @@ -300,11 +300,12 @@ def downscaled_centroids( for comp, centr in zip(component_labels, centroids_scaled): if comp != 0: new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp - new_array = np.round(new_array).astype(int) else: raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.") + new_array = np.round(new_array).astype(int) + return new_array @@ -531,7 +532,7 @@ def component_labels_graph(table: pd.DataFrame) -> List[int]: component_labels = [0 for _ in range(len(table))] for lab, comp in enumerate(components): for comp_index in comp: - component_labels[comp_index] = lab + 1 + component_labels[comp_index - 1] = lab + 1 return component_labels From 67193f5fe814a886e3b1dd74e84b916b381dc77e Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 8 May 2025 10:00:11 +0200 Subject: [PATCH 5/8] Removed post-processing with downsampling --- flamingo_tools/segmentation/postprocessing.py | 227 ++++-------------- 1 file changed, 47 insertions(+), 180 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 6b09f9d..86f0373 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -10,9 +10,6 @@ import pandas as pd from elf.io import open_file -from scipy.ndimage import binary_fill_holes -from scipy.ndimage import distance_transform_edt -from scipy.ndimage import label from scipy.sparse import csr_matrix from scipy.spatial import distance from scipy.spatial import cKDTree, ConvexHull @@ -212,11 +209,6 @@ def filter_chunk(block_id): return n_ids, n_ids_filtered -# Postprocess segmentation by erosion using the above spatial statistics. -# Currently implemented using downscaling and looking for connected components -# TODO: Change implementation to graph connected components. 
- - def erode_subset( table: pd.DataFrame, iterations: Optional[int] = 1, @@ -242,7 +234,6 @@ def erode_subset( for i in range(iterations): table = table[table[keyword] < threshold] - # TODO: support other spatial statistics distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors) if min_cells is not None and len(distance_avg) < min_cells: @@ -309,50 +300,14 @@ def downscaled_centroids( return new_array -def coordinates_in_downscaled_blocks( - table: pd.DataFrame, - down_array: np.typing.NDArray, - scale_factor: float, - distance_component: Optional[int] = 0, -) -> List[int]: - """Checking if coordinates are within the downscaled array. - - Args: - table: Dataframe of segmentation table. - down_array: Downscaled array. - scale_factor: Factor which was used for downscaling. - distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included. - - Returns: - A binary list representing whether the dataframe coordinates are within the array. - """ - # fill holes in down-sampled array - down_array[down_array > 0] = 1 - down_array = binary_fill_holes(down_array).astype(np.uint8) - - # check if input coordinates are within down-sampled blocks - centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) - centroids = [np.floor(np.array([c[0]/scale_factor, c[1]/scale_factor, c[2]/scale_factor])) for c in centroids] - - distance_map = distance_transform_edt(down_array == 0) - - centroids_binary = [] - for c in centroids: - coord = (int(c[0]), int(c[1]), int(c[2])) - if down_array[coord] != 0: - centroids_binary.append(1) - elif distance_map[coord] <= distance_component: - centroids_binary.append(1) - else: - centroids_binary.append(0) - - return centroids_binary - - -def erode_sgn_seg_graph( +def components_sgn( table: pd.DataFrame, keyword: Optional[str] = "distance_nn100", threshold_erode: Optional[float] = None, + postprocess_graph: Optional[bool] = False, + min_component_length: Optional[int] = 50, + min_edge_distance: Optional[float] = 30, + iterations_erode: Optional[int] = None, ) -> List[List[int]]: """Eroding the SGN segmentation. @@ -360,21 +315,28 @@ def erode_sgn_seg_graph( table: Dataframe of segmentation table. keyword: Keyword of the dataframe column for erosion. threshold_erode: Threshold of column value after erosion step with spatial statistics. + postprocess_graph: Post-process graph connected components by searching for near points. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + iterations_erode: Number of iterations for erosion, normally determined automatically. Returns: Subgraph components as lists of label_ids of dataframe. 
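+
+    A minimal usage sketch (assumes a MoBIE-style table with 'anchor_x', 'anchor_y', 'anchor_z',
+    'label_id' and a precomputed 'distance_nn100' column, e.g. from nearest_neighbor_distance):
+
+        components = components_sgn(table, min_component_length=50, min_edge_distance=30)
+        largest = max(components, key=len)  # label_ids of the largest connected component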
""" + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + labels = [int(i) for i in list(table["label_id"])] + print("initial length", len(table)) distance_nn = list(table[keyword]) distance_nn.sort() if len(table) < 20000: - iterations = 1 + iterations = iterations_erode if iterations_erode is not None else 0 min_cells = None average_dist = int(distance_nn[int(len(table) * 0.8)]) threshold = threshold_erode if threshold_erode is not None else average_dist else: - iterations = 15 + iterations = iterations_erode if iterations_erode is not None else 15 min_cells = 20000 threshold = threshold_erode if threshold_erode is not None else 40 @@ -394,142 +356,69 @@ def erode_sgn_seg_graph( for num, pos in coords.items(): graph.add_node(num, pos=pos) - # create edges between points whose distance is less than threshold - threshold = 30 + # create edges between points whose distance is less than threshold min_edge_distance for i in coords: for j in coords: if i < j: dist = math.dist(coords[i], coords[j]) - if dist <= threshold: + if dist <= min_edge_distance: graph.add_edge(i, j, weight=dist) components = list(nx.connected_components(graph)) - # remove connected components with less nodes than threshold - min_length = 100 + # remove connected components with less nodes than threshold min_component_length for component in components: - if len(component) < min_length: + if len(component) < min_component_length: for c in component: graph.remove_node(c) - components = list(nx.connected_components(graph)) + components = [list(s) for s in nx.connected_components(graph)] + + # add original coordinates closer to eroded component than threshold + if postprocess_graph: + threshold = 15 + for label_id, centr in zip(labels, centroids): + if label_id not in labels_subset: + add_coord = [] + for comp_index, component in enumerate(components): + for comp_label in component: + dist = math.dist(centr, centroids[comp_label - 1]) + if dist <= threshold: + add_coord.append([comp_index, label_id]) + break + if len(add_coord) != 0: + components[add_coord[0][0]].append(add_coord[0][1]) return components -def erode_sgn_seg_downscaling( +def label_components( table: pd.DataFrame, - keyword: Optional[str] = "distance_nn100", - filter_small_components: Optional[int] = None, - scale_factor: Optional[float] = 20, threshold_erode: Optional[float] = None, -) -> Tuple[np.typing.NDArray, np.typing.NDArray]: - """Eroding the SGN segmentation. - - Args: - table: Dataframe of segmentation table. - keyword: Keyword of the dataframe column for erosion. - filter_small_components: Filter components smaller after n blocks after labeling. - scale_factor: Scaling for downsampling. - threshold_erode: Threshold of column value after erosion step with spatial statistics. - - Returns: - The labeled components of the downscaled, eroded coordinates. - The larget connected component of the labeled components. 
- """ - ref_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) - print("initial length", len(table)) - distance_nn = list(table[keyword]) - distance_nn.sort() - - if len(table) < 20000: - iterations = 1 - min_cells = None - average_dist = int(distance_nn[int(len(table) * 0.8)]) - threshold = threshold_erode if threshold_erode is not None else average_dist - else: - iterations = 15 - min_cells = 20000 - threshold = threshold_erode if threshold_erode is not None else 40 - - print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") - - new_subset = erode_subset(table.copy(), iterations=iterations, - threshold=threshold, min_cells=min_cells, keyword=keyword) - - eroded_arr = downscaled_centroids(new_subset, scale_factor=scale_factor, ref_dimensions=ref_dimensions) - - # Label connected components - labeled, num_features = label(eroded_arr) - - # Find the largest component - sizes = [(labeled == i).sum() for i in range(1, num_features + 1)] - largest_label = np.argmax(sizes) + 1 - - # Extract only the largest component - largest_component = (labeled == largest_label).astype(np.uint8) - largest_component_filtered = binary_fill_holes(largest_component).astype(np.uint8) - - # filter small sizes - if filter_small_components is not None: - for (size, feature) in zip(sizes, range(1, num_features + 1)): - if size < filter_small_components: - labeled[labeled == feature] = 0 - - return labeled, largest_component_filtered - - -def get_components( - table: pd.DataFrame, - labeled: np.typing.NDArray, - scale_factor: float, - distance_component: Optional[int] = 0, + min_component_length: Optional[int] = 50, + min_edge_distance: Optional[float] = 30, + iterations_erode: Optional[int] = None, ) -> List[int]: - """Indexing coordinates according to labeled array. - - Args: - table: Dataframe of segmentation table. - labeled: Array containing differently labeled components. - scale_factor: Scaling for downsampling. - distance_component: Distance in downscaled units to which centroids next to downscaled blocks are included. - - Returns: - List of component labels. - """ - unique_labels = list(np.unique(labeled)) - - # sort non-background labels according to size, descending - unique_labels = [i for i in unique_labels if i != 0] - sizes = [(labeled == i).sum() for i in unique_labels] - sizes, unique_labels = zip(*sorted(zip(sizes, unique_labels), reverse=True)) - - component_labels = [0 for _ in range(len(table))] - for label_index, l in enumerate(unique_labels): - label_arr = (labeled == l).astype(np.uint8) - centroids_binary = coordinates_in_downscaled_blocks(table, label_arr, - scale_factor, distance_component=distance_component) - for num, c in enumerate(centroids_binary): - if c != 0: - component_labels[num] = label_index + 1 - - return component_labels - - -def component_labels_graph(table: pd.DataFrame) -> List[int]: """Label components using graph connected components. Args: table: Dataframe of segmentation table. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + iterations_erode: Number of iterations for erosion, normally determined automatically. Returns: - List of component label for each point in dataframe. + List of component label for each point in dataframe. 
0 - background, then in descending order of size """ - components = erode_sgn_seg_graph(table) + components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length, + min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) length_components = [len(c) for c in components] length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) component_labels = [0 for _ in range(len(table))] + # be aware of 'label_id' of dataframe starting at 1 for lab, comp in enumerate(components): for comp_index in comp: component_labels[comp_index - 1] = lab + 1 @@ -537,40 +426,18 @@ def component_labels_graph(table: pd.DataFrame) -> List[int]: return component_labels -def component_labels_downscaling(table: pd.DataFrame, scale_factor: float = 20) -> List[int]: - """Label components using downscaling and connected components. - - Args: - table: Dataframe of segmentation table. - scale_factor: Factor for downscaling. - - Returns: - List of component label for each point in dataframe. - """ - labeled, largest_component = erode_sgn_seg_downscaling(table, filter_small_components=10, - scale_factor=scale_factor, threshold_erode=None) - component_labels = get_components(table, labeled, scale_factor, distance_component=1) - - return component_labels - - def postprocess_sgn_seg( table: pd.DataFrame, - postprocess_type: Optional[str] = "downsampling", ) -> pd.DataFrame: """Postprocessing SGN segmentation of cochlea. Args: table: Dataframe of segmentation table. - postprocess_type: Postprocessing method, either 'downsampling' or 'graph'. Returns: Dataframe with component labels. """ - if postprocess_type == "downsampling": - component_labels = component_labels_downscaling(table) - elif postprocess_type == "graph": - component_labels = component_labels_graph(table) + component_labels = label_components(table) table.loc[:, "component_labels"] = component_labels From bf580718bbdc3dd63f7633f8bda88d4cab3cdfc2 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 8 May 2025 17:02:18 +0200 Subject: [PATCH 6/8] Updated script --- flamingo_tools/segmentation/postprocessing.py | 33 ++++- scripts/prediction/postprocess_seg.py | 128 ++++++++++++------ 2 files changed, 113 insertions(+), 48 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 86f0373..50d22ee 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -222,7 +222,7 @@ def erode_subset( Args: table: Dataframe of segmentation table. iterations: Number of steps for erosion process. - min_cells: Minimal number of rows. The erosion is stopped before reaching this number. + min_cells: Minimal number of rows. The erosion is stopped after falling below this limit. threshold: Upper threshold for removing elements according to the given keyword. keyword: Keyword of dataframe for erosion. @@ -259,7 +259,7 @@ def downscaled_centroids( table: Dataframe of segmentation table. scale_factor: Factor for downscaling coordinates. ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. - downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components' + downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'. 
Returns: The downscaled array @@ -326,7 +326,6 @@ def components_sgn( centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) labels = [int(i) for i in list(table["label_id"])] - print("initial length", len(table)) distance_nn = list(table[keyword]) distance_nn.sort() @@ -394,6 +393,7 @@ def components_sgn( def label_components( table: pd.DataFrame, + min_size: Optional[int] = 1000, threshold_erode: Optional[float] = None, min_component_length: Optional[int] = 50, min_edge_distance: Optional[float] = 30, @@ -403,6 +403,7 @@ def label_components( Args: table: Dataframe of segmentation table. + min_size: Minimal number of pixels for filtering small instances. threshold_erode: Threshold of column value after erosion step with spatial statistics. min_component_length: Minimal length for filtering out connected components. min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. @@ -411,9 +412,18 @@ def label_components( Returns: List of component label for each point in dataframe. 0 - background, then in descending order of size """ + + # First, apply the size filter. + entries_filtered = table[table.n_pixels < min_size] + table = table[table.n_pixels >= min_size] + components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length, min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) + # add size-filtered objects to have same initial length + table = pd.concat([table, entries_filtered], ignore_index=True) + table.sort_values("label_id") + length_components = [len(c) for c in components] length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) @@ -428,17 +438,30 @@ def label_components( def postprocess_sgn_seg( table: pd.DataFrame, + min_size: Optional[int] = 1000, + threshold_erode: Optional[float] = None, + min_component_length: Optional[int] = 50, + min_edge_distance: Optional[float] = 30, + iterations_erode: Optional[int] = None, ) -> pd.DataFrame: """Postprocessing SGN segmentation of cochlea. Args: table: Dataframe of segmentation table. + min_size: Minimal number of pixels for filtering small instances. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + iterations_erode: Number of iterations for erosion, normally determined automatically. Returns: Dataframe with component labels. 
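+
+    A minimal usage sketch (assumes a MoBIE-style segmentation table with 'anchor_x', 'anchor_y',
+    'anchor_z', 'label_id', 'n_pixels' and a precomputed 'distance_nn100' column; the file name
+    is hypothetical):
+
+        table = pd.read_csv("segmentation_table.tsv", sep="\t")
+        table = postprocess_sgn_seg(table, min_size=1000, min_edge_distance=30)
+        largest = table[table["component_labels"] == 1]  # rows of the largest component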
""" - component_labels = label_components(table) - table.loc[:, "component_labels"] = component_labels + comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode, + min_component_length=min_component_length, + min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) + + table.loc[:, "component_labels"] = comp_labels return table diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 0134539..66735ab 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -7,6 +7,7 @@ import flamingo_tools.s3_utils as s3_utils from flamingo_tools.segmentation import filter_segmentation from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius +from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg # TODO needs updates @@ -15,10 +16,13 @@ def main(): parser = argparse.ArgumentParser( description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.") - parser.add_argument("-o", "--output_folder", type=str, required=True) + parser.add_argument("-o", "--output_folder", type=str, default=None) parser.add_argument("-t", "--tsv", type=str, default=None, help="TSV-file in MoBIE format which contains information about segmentation.") + parser.add_argument("--tsv_out", type=str, default=None, + help="File path to save post-processed dataframe. Default: default.tsv") + parser.add_argument('-k', "--input_key", type=str, default="segmentation", help="The key / internal path of the segmentation.") parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", @@ -26,7 +30,20 @@ def main(): parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of segmentation in micrometer.") - parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.") + # options for post-processing + parser.add_argument("--min_size", type=int, default=1000, + help="Minimal number of pixels for filtering small instances.") + parser.add_argument("--threshold", type=float, default=None, + help="Threshold for spatial statistics.") + parser.add_argument("--min_component_length", type=int, default=50, + help="Minimal length for filtering out connected components.") + parser.add_argument("--min_edge_dist", type=float, default=30, + help="Minimal distance in micrometer between points to create edges for connected components.") + parser.add_argument("--iterations_erode", type=int, default=None, + help="Number of iterations for erosion, normally determined automatically.") + + # options for S3 bucket + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. " "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") @@ -35,23 +52,42 @@ def main(): parser.add_argument("--s3_service_endpoint", type=str, default=None, help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") - parser.add_argument("--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object") - + # options for spatial statistics parser.add_argument("--n_neighbors", type=int, default=None, help="Value for calculating distance to 'n' nearest neighbors.") - parser.add_argument("--local_ripley_radius", type=int, default=None, help="Value for radius for calculating local Ripley's K function.") - parser.add_argument("--r_neighbors", type=int, default=None, help="Value for radius for calculating number of neighbors in range.") args = parser.parse_args() + if args.output_folder is None and args.tsv is None: + raise ValueError("Either supply an output folder containing 'segmentation.zarr' or a TSV-file in MoBIE format.") + + # check output folder + if args.output_folder is not None: + seg_path = os.path.join(args.output_folder, "segmentation.zarr") + if args.s3: + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + with zarr.open(s3_path, mode="r") as f: + segmentation = f[args.input_key] + else: + with zarr.open(seg_path, mode="r") as f: + segmentation = f[args.input_key] + else: + seg_path = None + + # check input for spatial statistics postprocess_functions = [nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius] function_keywords = ["n_neighbors", "radius", "radius"] postprocess_options = [args.n_neighbors, args.local_ripley_radius, args.r_neighbors] - default_thresholds = [15, 20, 20] + default_thresholds = [args.threshold for _ in postprocess_functions] + + if seg_path is not None and args.threshold is None: + default_thresholds = [15, 20, 20] def create_spatial_statistics_dict(functions, keyword, options, threshold): spatial_statistics_dict = [] @@ -62,52 +98,58 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold): spatial_statistics_dict = create_spatial_statistics_dict(postprocess_functions, postprocess_options, function_keywords, default_thresholds) - - if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0: - raise ValueError("Choose a postprocess function from 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.") - elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1: - raise ValueError("The script only supports a single postprocess function.") - else: - for d in spatial_statistics_dict: - if d["argument"] is not None: - spatial_statistics = d["function"] - spatial_statistics_kwargs = {d["keyword"]: d["argument"]} - threshold = d["threshold"] - - seg_path = os.path.join(args.output_folder, "segmentation.zarr") - + if seg_path is not None: + if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0: + raise ValueError("Choose a postprocess function: 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.") + elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1: + raise ValueError("The script only supports a single postprocess function.") + else: + for d in spatial_statistics_dict: + if d["argument"] is not None: + spatial_statistics = d["function"] + spatial_statistics_kwargs = {d["keyword"]: d["argument"]} + threshold = d["threshold"] + + # check TSV-file containing data in MoBIE format tsv_table = None - - if args.s3_input is not None: - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, - service_endpoint=args.s3_service_endpoint, - 
credential_file=args.s3_credentials) - with zarr.open(s3_path, mode="r") as f: - segmentation = f[args.input_key] - - if args.tsv is not None: + if args.tsv is not None: + if args.s3: tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") - - else: - with zarr.open(seg_path, mode="r") as f: - segmentation = f[args.input_key] - - if args.tsv is not None: + else: with open(args.tsv, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") - n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path, - spatial_statistics=spatial_statistics, - threshold=threshold, - min_size=args.min_size, table=tsv_table, - resolution=args.resolution, - output_key=args.output_key, **spatial_statistics_kwargs) + if seg_path is None: + post_table = postprocess_sgn_seg( + tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold, + min_component_length=args.min_component_length, min_edge_distance=args.min_edge_dist, + iterations_erode=args.iterations_erode, + ) + + if args.tsv_out is None: + out_path = "default.tsv" + else: + out_path = args.tsv_out + post_table.to_csv(out_path, sep="\t", index=False) + + n_pre = len(tsv_table) + n_post = len(post_table["component_labels"][post_table["component_labels"] == 1]) - print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}") + print(f"Number of pre-filtered objects: {n_pre}\nNumber of objects in largest component: {n_post}") + + else: + n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path, + spatial_statistics=spatial_statistics, + threshold=threshold, + min_size=args.min_size, table=tsv_table, + resolution=args.resolution, + output_key=args.output_key, **spatial_statistics_kwargs) + + print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}") if __name__ == "__main__": From 261f255dd7e208633cb4d12ccc902b65cf62a9e4 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Tue, 13 May 2025 08:49:19 +0200 Subject: [PATCH 7/8] Fixed Optional syntax --- flamingo_tools/segmentation/postprocessing.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 50d22ee..75529b6 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -211,10 +211,10 @@ def filter_chunk(block_id): def erode_subset( table: pd.DataFrame, - iterations: Optional[int] = 1, + iterations: int = 1, min_cells: Optional[int] = None, - threshold: Optional[int] = 35, - keyword: Optional[str] = "distance_nn100", + threshold: int = 35, + keyword: str = "distance_nn100", ) -> pd.DataFrame: """Erode coordinates of dataframe according to a keyword and a threshold. Use a copy of the dataframe as an input, if it should not be edited. @@ -251,7 +251,7 @@ def downscaled_centroids( table: pd.DataFrame, scale_factor: int, ref_dimensions: Optional[Tuple[float, float, float]] = None, - downsample_mode: Optional[str] = "accumulated", + downsample_mode: str = "accumulated", ) -> np.typing.NDArray: """Downscale centroids in dataframe. 
@@ -302,11 +302,11 @@ def downscaled_centroids( def components_sgn( table: pd.DataFrame, - keyword: Optional[str] = "distance_nn100", + keyword: str = "distance_nn100", threshold_erode: Optional[float] = None, - postprocess_graph: Optional[bool] = False, - min_component_length: Optional[int] = 50, - min_edge_distance: Optional[float] = 30, + postprocess_graph: bool = False, + min_component_length: int = 50, + min_edge_distance: float = 30, iterations_erode: Optional[int] = None, ) -> List[List[int]]: """Eroding the SGN segmentation. @@ -393,10 +393,10 @@ def components_sgn( def label_components( table: pd.DataFrame, - min_size: Optional[int] = 1000, + min_size: int = 1000, threshold_erode: Optional[float] = None, - min_component_length: Optional[int] = 50, - min_edge_distance: Optional[float] = 30, + min_component_length: int = 50, + min_edge_distance: float = 30, iterations_erode: Optional[int] = None, ) -> List[int]: """Label components using graph connected components. @@ -438,10 +438,10 @@ def label_components( def postprocess_sgn_seg( table: pd.DataFrame, - min_size: Optional[int] = 1000, + min_size: int = 1000, threshold_erode: Optional[float] = None, - min_component_length: Optional[int] = 50, - min_edge_distance: Optional[float] = 30, + min_component_length: int = 50, + min_edge_distance: float = 30, iterations_erode: Optional[int] = None, ) -> pd.DataFrame: """Postprocessing SGN segmentation of cochlea. From 6c4e504b9d69b1ef2aa6ed6b0874c24e2ab285c3 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 13 May 2025 12:56:00 +0200 Subject: [PATCH 8/8] Fix errors caused by zarr --- environment.yaml | 3 ++- flamingo_tools/file_utils.py | 7 ++++++- flamingo_tools/s3_utils.py | 7 ++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/environment.yaml b/environment.yaml index 2aedfb7..05b23c6 100644 --- a/environment.yaml +++ b/environment.yaml @@ -13,4 +13,5 @@ dependencies: - s3fs - torch_em - z5py - - zarr + # Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet + - zarr <3 diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index f121923..acfcdca 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -7,6 +7,11 @@ import zarr from elf.io import open_file +try: + from zarr.abc.store import Store +except ImportError: + from zarr._storage.store import BaseStore as Store + def _parse_shape(metadata_file): depth, height, width = None, None, None @@ -62,7 +67,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: return x -def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike: +def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike: """Read flamingo image data, stored in various formats. 
Args: diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index 4dc6a23..7e8aa27 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -7,6 +7,11 @@ import s3fs import zarr +try: + from zarr.abc.store import Store +except ImportError: + from zarr._storage.store import BaseStore as Store + # Dedicated bucket for cochlea lightsheet project MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" @@ -93,7 +98,7 @@ def get_s3_path( bucket_name: Optional[str] = None, service_endpoint: Optional[str] = None, credential_file: Optional[str] = None, -) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]: +) -> Tuple[Store, s3fs.core.S3FileSystem]: """Get S3 path for a file or folder and file system based on S3 parameters and credentials. Args: