diff --git a/environment.yaml b/environment.yaml index 2aedfb7..05b23c6 100644 --- a/environment.yaml +++ b/environment.yaml @@ -13,4 +13,5 @@ dependencies: - s3fs - torch_em - z5py - - zarr + # Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet + - zarr <3 diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index f121923..acfcdca 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -7,6 +7,11 @@ import zarr from elf.io import open_file +try: + from zarr.abc.store import Store +except ImportError: + from zarr._storage.store import BaseStore as Store + def _parse_shape(metadata_file): depth, height, width = None, None, None @@ -62,7 +67,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: return x -def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike: +def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> np.typing.ArrayLike: """Read flamingo image data, stored in various formats. Args: diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index 4dc6a23..7e8aa27 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -7,6 +7,11 @@ import s3fs import zarr +try: + from zarr.abc.store import Store +except ImportError: + from zarr._storage.store import BaseStore as Store + # Dedicated bucket for cochlea lightsheet project MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" @@ -93,7 +98,7 @@ def get_s3_path( bucket_name: Optional[str] = None, service_endpoint: Optional[str] = None, credential_file: Optional[str] = None, -) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]: +) -> Tuple[Store, s3fs.core.S3FileSystem]: """Get S3 path for a file or folder and file system based on S3 parameters and credentials. Args: diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 7ad987b..75529b6 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -1,15 +1,17 @@ +import math import multiprocessing as mp from concurrent import futures -from typing import Callable, Tuple, Optional +from typing import Callable, List, Optional, Tuple import elf.parallel as parallel import numpy as np import nifty.tools as nt +import networkx as nx import pandas as pd from elf.io import open_file -from scipy.spatial import distance from scipy.sparse import csr_matrix +from scipy.spatial import distance from scipy.spatial import cKDTree, ConvexHull from skimage import measure from sklearn.neighbors import NearestNeighbors @@ -205,3 +207,261 @@ def filter_chunk(block_id): ) return n_ids, n_ids_filtered + + +def erode_subset( + table: pd.DataFrame, + iterations: int = 1, + min_cells: Optional[int] = None, + threshold: int = 35, + keyword: str = "distance_nn100", +) -> pd.DataFrame: + """Erode coordinates of dataframe according to a keyword and a threshold. + Use a copy of the dataframe as an input, if it should not be edited. + + Args: + table: Dataframe of segmentation table. + iterations: Number of steps for erosion process. + min_cells: Minimal number of rows. The erosion is stopped after falling below this limit. + threshold: Upper threshold for removing elements according to the given keyword. + keyword: Keyword of dataframe for erosion. + + Returns: + The dataframe containing elements left after the erosion. 
+ """ + print("initial length", len(table)) + n_neighbors = 100 + for i in range(iterations): + table = table[table[keyword] < threshold] + + distance_avg = nearest_neighbor_distance(table, n_neighbors=n_neighbors) + + if min_cells is not None and len(distance_avg) < min_cells: + print(f"{i}-th iteration, length of subset {len(table)}, stopping erosion") + break + + table.loc[:, 'distance_nn'+str(n_neighbors)] = list(distance_avg) + + print(f"{i}-th iteration, length of subset {len(table)}") + + return table + + +def downscaled_centroids( + table: pd.DataFrame, + scale_factor: int, + ref_dimensions: Optional[Tuple[float, float, float]] = None, + downsample_mode: str = "accumulated", +) -> np.typing.NDArray: + """Downscale centroids in dataframe. + + Args: + table: Dataframe of segmentation table. + scale_factor: Factor for downscaling coordinates. + ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. + downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'. + + Returns: + The downscaled array + """ + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids] + + if ref_dimensions is None: + bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"])) + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in bounding_dimensions]) + new_array = np.zeros(bounding_dimensions_scaled) + + else: + bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in ref_dimensions]) + new_array = np.zeros(bounding_dimensions_scaled) + + if downsample_mode == "accumulated": + for c in centroids_scaled: + new_array[int(c[0]), int(c[1]), int(c[2])] += 1 + + elif downsample_mode == "capped": + for c in centroids_scaled: + new_array[int(c[0]), int(c[1]), int(c[2])] = 1 + + elif downsample_mode == "components": + if "component_labels" not in table.columns: + raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.") + component_labels = list(table["component_labels"]) + for comp, centr in zip(component_labels, centroids_scaled): + if comp != 0: + new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp + + else: + raise ValueError("Choose one of the downsampling modes 'accumulated', 'capped', or 'components'.") + + new_array = np.round(new_array).astype(int) + + return new_array + + +def components_sgn( + table: pd.DataFrame, + keyword: str = "distance_nn100", + threshold_erode: Optional[float] = None, + postprocess_graph: bool = False, + min_component_length: int = 50, + min_edge_distance: float = 30, + iterations_erode: Optional[int] = None, +) -> List[List[int]]: + """Eroding the SGN segmentation. + + Args: + table: Dataframe of segmentation table. + keyword: Keyword of the dataframe column for erosion. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + postprocess_graph: Post-process graph connected components by searching for near points. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + iterations_erode: Number of iterations for erosion, normally determined automatically. + + Returns: + Subgraph components as lists of label_ids of dataframe. 
+ """ + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + labels = [int(i) for i in list(table["label_id"])] + + distance_nn = list(table[keyword]) + distance_nn.sort() + + if len(table) < 20000: + iterations = iterations_erode if iterations_erode is not None else 0 + min_cells = None + average_dist = int(distance_nn[int(len(table) * 0.8)]) + threshold = threshold_erode if threshold_erode is not None else average_dist + else: + iterations = iterations_erode if iterations_erode is not None else 15 + min_cells = 20000 + threshold = threshold_erode if threshold_erode is not None else 40 + + print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") + + new_subset = erode_subset(table.copy(), iterations=iterations, + threshold=threshold, min_cells=min_cells, keyword=keyword) + + # create graph from coordinates of eroded subset + centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) + labels_subset = [int(i) for i in list(new_subset["label_id"])] + coords = {} + for index, element in zip(labels_subset, centroids_subset): + coords[index] = element + + graph = nx.Graph() + for num, pos in coords.items(): + graph.add_node(num, pos=pos) + + # create edges between points whose distance is less than threshold min_edge_distance + for i in coords: + for j in coords: + if i < j: + dist = math.dist(coords[i], coords[j]) + if dist <= min_edge_distance: + graph.add_edge(i, j, weight=dist) + + components = list(nx.connected_components(graph)) + + # remove connected components with less nodes than threshold min_component_length + for component in components: + if len(component) < min_component_length: + for c in component: + graph.remove_node(c) + + components = [list(s) for s in nx.connected_components(graph)] + + # add original coordinates closer to eroded component than threshold + if postprocess_graph: + threshold = 15 + for label_id, centr in zip(labels, centroids): + if label_id not in labels_subset: + add_coord = [] + for comp_index, component in enumerate(components): + for comp_label in component: + dist = math.dist(centr, centroids[comp_label - 1]) + if dist <= threshold: + add_coord.append([comp_index, label_id]) + break + if len(add_coord) != 0: + components[add_coord[0][0]].append(add_coord[0][1]) + + return components + + +def label_components( + table: pd.DataFrame, + min_size: int = 1000, + threshold_erode: Optional[float] = None, + min_component_length: int = 50, + min_edge_distance: float = 30, + iterations_erode: Optional[int] = None, +) -> List[int]: + """Label components using graph connected components. + + Args: + table: Dataframe of segmentation table. + min_size: Minimal number of pixels for filtering small instances. + threshold_erode: Threshold of column value after erosion step with spatial statistics. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + iterations_erode: Number of iterations for erosion, normally determined automatically. + + Returns: + List of component label for each point in dataframe. 0 - background, then in descending order of size + """ + + # First, apply the size filter. 
+    entries_filtered = table[table.n_pixels < min_size]
+    table = table[table.n_pixels >= min_size]
+
+    components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
+                                min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    # add size-filtered objects to have same initial length
+    table = pd.concat([table, entries_filtered], ignore_index=True)
+    table = table.sort_values("label_id")
+
+    length_components = [len(c) for c in components]
+    length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
+
+    component_labels = [0 for _ in range(len(table))]
+    # note that the 'label_id' column of the dataframe starts at 1
+    for lab, comp in enumerate(components):
+        for comp_index in comp:
+            component_labels[comp_index - 1] = lab + 1
+
+    return component_labels
+
+
+def postprocess_sgn_seg(
+    table: pd.DataFrame,
+    min_size: int = 1000,
+    threshold_erode: Optional[float] = None,
+    min_component_length: int = 50,
+    min_edge_distance: float = 30,
+    iterations_erode: Optional[int] = None,
+) -> pd.DataFrame:
+    """Postprocess the SGN segmentation of the cochlea.
+
+    Args:
+        table: Dataframe of segmentation table.
+        min_size: Minimal number of pixels for filtering small instances.
+        threshold_erode: Threshold of column value after erosion step with spatial statistics.
+        min_component_length: Minimal length for filtering out connected components.
+        min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
+        iterations_erode: Number of iterations for erosion, normally determined automatically.
+
+    Returns:
+        Dataframe with component labels.
+    """
+
+    comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
+                                   min_component_length=min_component_length,
+                                   min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
+
+    table.loc[:, "component_labels"] = comp_labels
+
+    return table
diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py
index 1b6d3f7..b8a71ff 100644
--- a/scripts/prediction/expand_seg_table.py
+++ b/scripts/prediction/expand_seg_table.py
@@ -61,7 +61,7 @@ def main(
         neighbor_counts = [n[0] for n in neighbor_counts]
         tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts
 
-    tsv_table.to_csv(out_path, sep="\t")
+    tsv_table.to_csv(out_path, sep="\t", index=False)
 
 
 if __name__ == "__main__":
diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py
index 0134539..66735ab 100644
--- a/scripts/prediction/postprocess_seg.py
+++ b/scripts/prediction/postprocess_seg.py
@@ -7,6 +7,7 @@ import flamingo_tools.s3_utils as s3_utils
 
 from flamingo_tools.segmentation import filter_segmentation
 from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius
+from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg
 
 
 # TODO needs updates
@@ -15,10 +16,13 @@ def main():
     parser = argparse.ArgumentParser(
         description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.")
 
-    parser.add_argument("-o", "--output_folder", type=str, required=True)
+    parser.add_argument("-o", "--output_folder", type=str, default=None)
     parser.add_argument("-t", "--tsv", type=str, default=None,
                         help="TSV-file in MoBIE format which contains information about segmentation.")
+    parser.add_argument("--tsv_out", type=str, default=None,
+                        help="File path to save the post-processed dataframe. Default: default.tsv")
+
     parser.add_argument('-k', "--input_key", type=str, default="segmentation",
                         help="The key / internal path of the segmentation.")
     parser.add_argument("--output_key", type=str, default="segmentation_postprocessed",
@@ -26,7 +30,20 @@
     parser.add_argument('-r', "--resolution", type=float, default=0.38,
                         help="Resolution of segmentation in micrometer.")
 
-    parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.")
+    # options for post-processing
+    parser.add_argument("--min_size", type=int, default=1000,
+                        help="Minimal number of pixels for filtering small instances.")
+    parser.add_argument("--threshold", type=float, default=None,
+                        help="Threshold for spatial statistics.")
+    parser.add_argument("--min_component_length", type=int, default=50,
+                        help="Minimal length for filtering out connected components.")
+    parser.add_argument("--min_edge_dist", type=float, default=30,
+                        help="Minimal distance in micrometer between points to create edges for connected components.")
+    parser.add_argument("--iterations_erode", type=int, default=None,
+                        help="Number of iterations for erosion, normally determined automatically.")
+
+    # options for S3 bucket
+    parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
     parser.add_argument("--s3_credentials", type=str, default=None,
                         help="Input file containing S3 credentials. "
                         "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
@@ -35,23 +52,42 @@ def main():
     parser.add_argument("--s3_service_endpoint", type=str, default=None,
                         help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
 
-    parser.add_argument("--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object")
-
+    # options for spatial statistics
     parser.add_argument("--n_neighbors", type=int, default=None,
                         help="Value for calculating distance to 'n' nearest neighbors.")
-
     parser.add_argument("--local_ripley_radius", type=int, default=None,
                         help="Value for radius for calculating local Ripley's K function.")
-
     parser.add_argument("--r_neighbors", type=int, default=None,
                         help="Value for radius for calculating number of neighbors in range.")
 
     args = parser.parse_args()
 
+    if args.output_folder is None and args.tsv is None:
+        raise ValueError("Either supply an output folder containing 'segmentation.zarr' or a TSV-file in MoBIE format.")
+
+    # check output folder
+    if args.output_folder is not None:
+        seg_path = os.path.join(args.output_folder, "segmentation.zarr")
+        if args.s3:
+            s3_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=args.s3_bucket_name,
+                                               service_endpoint=args.s3_service_endpoint,
+                                               credential_file=args.s3_credentials)
+            with zarr.open(s3_path, mode="r") as f:
+                segmentation = f[args.input_key]
+        else:
+            with zarr.open(seg_path, mode="r") as f:
+                segmentation = f[args.input_key]
+    else:
+        seg_path = None
+
+    # check input for spatial statistics
     postprocess_functions = [nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius]
     function_keywords = ["n_neighbors", "radius", "radius"]
     postprocess_options = [args.n_neighbors, args.local_ripley_radius, args.r_neighbors]
-    default_thresholds = [15, 20, 20]
+    default_thresholds = [args.threshold for _ in postprocess_functions]
+
+    if seg_path is not None and args.threshold is None:
+        default_thresholds = [15, 20, 20]
 
     def create_spatial_statistics_dict(functions, keyword, options, threshold):
         spatial_statistics_dict = []
@@ -62,52 +98,58 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold):
 
     spatial_statistics_dict = create_spatial_statistics_dict(postprocess_functions, postprocess_options,
                                                               function_keywords, default_thresholds)
-
-    if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0:
-        raise ValueError("Choose a postprocess function from 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.")
-    elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1:
-        raise ValueError("The script only supports a single postprocess function.")
-    else:
-        for d in spatial_statistics_dict:
-            if d["argument"] is not None:
-                spatial_statistics = d["function"]
-                spatial_statistics_kwargs = {d["keyword"]: d["argument"]}
-                threshold = d["threshold"]
-
-    seg_path = os.path.join(args.output_folder, "segmentation.zarr")
-
+    if seg_path is not None:
+        if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0:
+            raise ValueError("Choose a postprocess function: 'n_neighbors', 'local_ripley_radius', or 'r_neighbors'.")
+        elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1:
+            raise ValueError("The script only supports a single postprocess function.")
+        else:
+            for d in spatial_statistics_dict:
+                if d["argument"] is not None:
+                    spatial_statistics = d["function"]
+                    spatial_statistics_kwargs = {d["keyword"]: d["argument"]}
+                    threshold = d["threshold"]
+
+    # check TSV-file containing data in MoBIE format
     tsv_table = None
-
-    if args.s3_input is not None:
-        s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name,
-                                           service_endpoint=args.s3_service_endpoint,
-                                           credential_file=args.s3_credentials)
-        with zarr.open(s3_path, mode="r") as f:
-            segmentation = f[args.input_key]
-
-        if args.tsv is not None:
+    if args.tsv is not None:
+        if args.s3:
             tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name,
                                                 service_endpoint=args.s3_service_endpoint,
                                                 credential_file=args.s3_credentials)
             with fs.open(tsv_path, 'r') as f:
                 tsv_table = pd.read_csv(f, sep="\t")
-
-    else:
-        with zarr.open(seg_path, mode="r") as f:
-            segmentation = f[args.input_key]
-
-        if args.tsv is not None:
+        else:
             with open(args.tsv, 'r') as f:
                 tsv_table = pd.read_csv(f, sep="\t")
 
-    n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path,
-                                        spatial_statistics=spatial_statistics,
-                                        threshold=threshold,
-                                        min_size=args.min_size, table=tsv_table,
-                                        resolution=args.resolution,
-                                        output_key=args.output_key, **spatial_statistics_kwargs)
+    if seg_path is None:
+        post_table = postprocess_sgn_seg(
+            tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold,
+            min_component_length=args.min_component_length, min_edge_distance=args.min_edge_dist,
+            iterations_erode=args.iterations_erode,
+        )
+
+        if args.tsv_out is None:
+            out_path = "default.tsv"
+        else:
+            out_path = args.tsv_out
+        post_table.to_csv(out_path, sep="\t", index=False)
+
+        n_pre = len(tsv_table)
+        n_post = len(post_table["component_labels"][post_table["component_labels"] == 1])
 
-    print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}")
+        print(f"Number of pre-filtered objects: {n_pre}\nNumber of objects in largest component: {n_post}")
+
+    else:
+        n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path,
+                                            spatial_statistics=spatial_statistics,
+                                            threshold=threshold,
+                                            min_size=args.min_size, table=tsv_table,
+                                            resolution=args.resolution,
+                                            output_key=args.output_key, **spatial_statistics_kwargs)
+
+        print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}")
 
 
 if __name__ == "__main__":
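
Usage note: a minimal sketch of how the new postprocess_sgn_seg entry point can be called directly on a MoBIE segmentation table, without going through postprocess_seg.py. It assumes the table already contains the columns used above ('label_id', 'anchor_x', 'anchor_y', 'anchor_z', 'n_pixels') plus a precomputed 'distance_nn100' column (e.g. written by expand_seg_table.py); the file paths are placeholders.

import pandas as pd
from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg

# Load a segmentation table in MoBIE format (tab-separated); the path is a placeholder.
table = pd.read_csv("segmentation_table.tsv", sep="\t")

# Label connected SGN components with the same defaults the command line script uses.
post_table = postprocess_sgn_seg(
    table.copy(),
    min_size=1000,             # drop instances with fewer than 1000 pixels
    threshold_erode=None,      # let the erosion threshold be chosen automatically
    min_component_length=50,   # discard connected components with fewer nodes
    min_edge_distance=30,      # edge radius in micrometer
    iterations_erode=None,     # number of erosion iterations, chosen automatically
)

# The result carries a 'component_labels' column: 0 marks background / size-filtered
# objects, 1 the largest component, and so on in descending order of size.
print((post_table["component_labels"] == 1).sum(), "objects in the largest component")
post_table.to_csv("segmentation_table_postprocessed.tsv", sep="\t", index=False)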