diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py
index 3b24ce9..d1bf25b 100644
--- a/flamingo_tools/segmentation/postprocessing.py
+++ b/flamingo_tools/segmentation/postprocessing.py
@@ -319,6 +319,41 @@ def downscaled_centroids(
     return new_array
 
 
+def graph_connected_components(coords: dict, min_edge_distance: float, min_component_length: int):
+    """Create a list of IDs for each connected component of a graph.
+
+    Args:
+        coords: Dictionary mapping label IDs to their positions.
+        min_edge_distance: Maximal distance between two nodes for creating an edge between them.
+        min_component_length: Minimal number of nodes in a connected component. Smaller components are filtered out.
+
+    Returns:
+        List of connected components, each given as a list of dictionary keys.
+    """
+    graph = nx.Graph()
+    for num, pos in coords.items():
+        graph.add_node(num, pos=pos)
+
+    # create edges between points whose distance does not exceed the threshold min_edge_distance
+    for num_i, pos_i in coords.items():
+        for num_j, pos_j in coords.items():
+            if num_i < num_j:
+                dist = math.dist(pos_i, pos_j)
+                if dist <= min_edge_distance:
+                    graph.add_edge(num_i, num_j, weight=dist)
+
+    components = list(nx.connected_components(graph))
+
+    # remove connected components with fewer nodes than the threshold min_component_length
+    for component in components:
+        if len(component) < min_component_length:
+            for c in component:
+                graph.remove_node(c)
+
+    components = [list(s) for s in nx.connected_components(graph)]
+    return components
+
+
 def components_sgn(
     table: pd.DataFrame,
     keyword: str = "distance_nn100",
@@ -370,27 +405,7 @@ def components_sgn(
     for index, element in zip(labels_subset, centroids_subset):
         coords[index] = element
 
-    graph = nx.Graph()
-    for num, pos in coords.items():
-        graph.add_node(num, pos=pos)
-
-    # create edges between points whose distance is less than threshold min_edge_distance
-    for i in coords:
-        for j in coords:
-            if i < j:
-                dist = math.dist(coords[i], coords[j])
-                if dist <= min_edge_distance:
-                    graph.add_edge(i, j, weight=dist)
-
-    components = list(nx.connected_components(graph))
-
-    # remove connected components with less nodes than threshold min_component_length
-    for component in components:
-        if len(component) < min_component_length:
-            for c in component:
-                graph.remove_node(c)
-
-    components = [list(s) for s in nx.connected_components(graph)]
+    components = graph_connected_components(coords, min_edge_distance, min_component_length)
 
     # add original coordinates closer to eroded component than threshold
     if postprocess_graph:
@@ -410,7 +425,7 @@ def components_sgn(
     return components
 
 
-def label_components(
+def label_components_sgn(
    table: pd.DataFrame,
     min_size: int = 1000,
     threshold_erode: Optional[float] = None,
@@ -418,7 +433,7 @@
     min_edge_distance: float = 30,
     iterations_erode: Optional[int] = None,
 ) -> List[int]:
-    """Label components using graph connected components.
+    """Label SGN components using graph connected components.
 
     Args:
         table: Dataframe of segmentation table.
@@ -477,9 +492,102 @@ def postprocess_sgn_seg(
         Dataframe with component labels.
""" - comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode, - min_component_length=min_component_length, - min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) + comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode, + min_component_length=min_component_length, + min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) + + table.loc[:, "component_labels"] = comp_labels + + return table + + +def components_ihc( + table: pd.DataFrame, + min_component_length: int = 50, + min_edge_distance: float = 30, +): + """Create connected components for IHC segmentation. + + Args: + table: Dataframe of segmentation table. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + + Returns: + Subgraph components as lists of label_ids of dataframe. + """ + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + labels = [int(i) for i in list(table["label_id"])] + coords = {} + for index, element in zip(labels, centroids): + coords[index] = element + + components = graph_connected_components(coords, min_edge_distance, min_component_length) + return components + + +def label_components_ihc( + table: pd.DataFrame, + min_size: int = 1000, + min_component_length: int = 50, + min_edge_distance: float = 30, +) -> List[int]: + """Label components using graph connected components. + + Args: + table: Dataframe of segmentation table. + min_size: Minimal number of pixels for filtering small instances. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + + Returns: + List of component label for each point in dataframe. 0 - background, then in descending order of size + """ + + # First, apply the size filter. + entries_filtered = table[table.n_pixels < min_size] + table = table[table.n_pixels >= min_size] + + components = components_ihc(table, min_component_length=min_component_length, + min_edge_distance=min_edge_distance) + + # add size-filtered objects to have same initial length + table = pd.concat([table, entries_filtered], ignore_index=True) + table.sort_values("label_id") + + length_components = [len(c) for c in components] + length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) + + component_labels = [0 for _ in range(len(table))] + # be aware of 'label_id' of dataframe starting at 1 + for lab, comp in enumerate(components): + for comp_index in comp: + component_labels[comp_index - 1] = lab + 1 + + return component_labels + + +def postprocess_ihc_seg( + table: pd.DataFrame, + min_size: int = 1000, + min_component_length: int = 50, + min_edge_distance: float = 30, +) -> pd.DataFrame: + """Postprocessing IHC segmentation of cochlea. + + Args: + table: Dataframe of segmentation table. + min_size: Minimal number of pixels for filtering small instances. + min_component_length: Minimal length for filtering out connected components. + min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + + Returns: + Dataframe with component labels. 
+    """
+
+    comp_labels = label_components_ihc(table, min_size=min_size,
+                                       min_component_length=min_component_length,
+                                       min_edge_distance=min_edge_distance)
 
     table.loc[:, "component_labels"] = comp_labels
 
     return table
diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py
index 6759cc2..99a0048 100644
--- a/flamingo_tools/segmentation/unet_prediction.py
+++ b/flamingo_tools/segmentation/unet_prediction.py
@@ -149,7 +149,9 @@ def postprocess(x):
     blocking = nt.blocking([0] * ndim, shape, block_shape)
     n_blocks = blocking.numberOfBlocks
     if prediction_instances != 1:
-        iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
+        # shuffle indices with a fixed seed to balance the segmentation blocks across slurm workers
+        rng = np.random.default_rng(seed=1234)
+        iteration_ids = [x.tolist() for x in np.array_split(list(rng.permutation(n_blocks)), prediction_instances)]
         slurm_iteration = iteration_ids[slurm_task_id]
     else:
         slurm_iteration = list(range(n_blocks))
@@ -175,7 +177,7 @@ def postprocess(x):
     return original_shape
 
 
-def find_mask(input_path: str, input_key: Optional[str], output_folder: str) -> None:
+def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg_class: Optional[str] = "sgn") -> None:
     """Determine the mask for running prediction.
 
     The mask corresponds to data that contains actual signal and not just noise.
@@ -187,10 +189,25 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str) ->
         input_path: The file path to the image data.
         input_key: The key / internal path of the image data.
         output_folder: The output folder for storing the mask data.
+        seg_class: Specifier for the exclusion criteria used in mask generation.
     """
     mask_path = os.path.join(output_folder, "mask.zarr")
     f = z5py.File(mask_path, "a")
 
+    # set parameters for the exclusion of chunks within mask generation
+    if seg_class == "sgn":
+        upper_percentile = 95
+        min_intensity = 200
+        print(f"Calculating mask for segmentation class {seg_class}.")
+    elif seg_class == "ihc":
+        upper_percentile = 99
+        min_intensity = 150
+        print(f"Calculating mask for segmentation class {seg_class}.")
+    else:
+        upper_percentile = 95
+        min_intensity = 200
+        print("Calculating mask with default values.")
+
     mask_key = "mask"
     if mask_key in f:
         return
@@ -209,8 +226,8 @@ def find_mask_block(block_id):
         block = blocking.getBlock(block_id)
         bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
         data = raw[bb]
-        max_ = np.percentile(data, 95)
-        if max_ > 200:
+        max_ = np.percentile(data, upper_percentile)
+        if max_ > min_intensity:
             ds_mask[bb] = 1
 
     n_threads = min(16, mp.cpu_count())
@@ -359,6 +376,7 @@
     center_distance_threshold: float = 0.4,
     boundary_distance_threshold: Optional[float] = None,
     fg_threshold: float = 0.5,
+    seg_class: Optional[str] = None,
 ) -> None:
     """Run prediction and segmentation with a distance U-Net.
 
@@ -377,12 +395,12 @@
         boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
             By default this is set to 'None', in which case the boundary distances are not used for the seeds.
         fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
+        seg_class: Specifier for the exclusion criteria used in mask generation.
""" os.makedirs(output_folder, exist_ok=True) if use_mask: - find_mask(input_path, input_key, output_folder) - + find_mask(input_path, input_key, output_folder, seg_class=seg_class) original_shape = prediction_impl( input_path, input_key, output_folder, model_path, scale, block_shape, halo ) @@ -403,12 +421,13 @@ def run_unet_prediction( def run_unet_prediction_preprocess_slurm( input_path: str, - input_key: Optional[str], output_folder: str, + input_key: Optional[str] = None, s3: Optional[str] = None, s3_bucket_name: Optional[str] = None, s3_service_endpoint: Optional[str] = None, s3_credentials: Optional[str] = None, + seg_class: Optional[str] = None, ) -> None: """Pre-processing for the parallel prediction with U-Net models. Masks are stored in mask.zarr in the output folder. @@ -417,12 +436,13 @@ def run_unet_prediction_preprocess_slurm( Args: input_path: The path to the input data. - input_key: The key / internal path of the image data. output_folder: The output folder for storing the segmentation related data. + input_key: The key / internal path of the image data. s3: Flag for considering input_path fo S3 bucket. s3_bucket_name: S3 bucket name. s3_service_endpoint: S3 service endpoint. s3_credentials: File path to credentials for S3 bucket. + seg_class: Specifier for exclusion criterias for mask generation. """ if s3 is not None: input_path, fs = s3_utils.get_s3_path( @@ -430,16 +450,17 @@ def run_unet_prediction_preprocess_slurm( ) if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): - find_mask(input_path, input_key, output_folder) + find_mask(input_path, input_key, output_folder, seg_class=seg_class) - calc_mean_and_std(input_path, input_key, output_folder) + if not os.path.isfile(os.path.join(output_folder, "mean_std.json")): + calc_mean_and_std(input_path, input_key, output_folder) def run_unet_prediction_slurm( input_path: str, - input_key: Optional[str], output_folder: str, model_path: str, + input_key: Optional[str] = None, scale: Optional[float] = None, block_shape: Optional[Tuple[int, int, int]] = None, halo: Optional[Tuple[int, int, int]] = None, @@ -453,9 +474,9 @@ def run_unet_prediction_slurm( Args: input_path: The path to the input data. - input_key: The key / internal path of the image data. output_folder: The output folder for storing the segmentation related data. model_path: The path to the model to use for segmentation. + input_key: The key / internal path of the image data. scale: A factor to rescale the data before prediction. By default the data will not be rescaled. block_shape: The block-shape for running the prediction. @@ -501,13 +522,26 @@ def run_unet_prediction_slurm( # does NOT need GPU, FIXME: only run on CPU -def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None: +def run_unet_segmentation_slurm( + output_folder: str, + min_size: int, + center_distance_threshold: float = 0.4, + boundary_distance_threshold: float = 0.5, + fg_threshold: float = 0.5, +) -> None: """Create segmentation from prediction. Args: output_folder: The output folder for storing the segmentation related data. min_size: The minimal size of segmented objects in the output. + center_distance_threshold: The threshold applied to the distance center predictions to derive seeds. + boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds. + By default this is set to 'None', in which case the boundary distances are not used for the seeds. 
+        fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
     """
     min_size = int(min_size)
     pmap_out = os.path.join(output_folder, "predictions.zarr")
-    distance_watershed_implementation(pmap_out, output_folder, min_size=min_size)
+    distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
+                                      boundary_distance_threshold=boundary_distance_threshold,
+                                      fg_threshold=fg_threshold,
+                                      min_size=min_size)
diff --git a/scripts/extract_block.py b/scripts/extract_block.py
index 6d5ade4..df1e5c5 100644
--- a/scripts/extract_block.py
+++ b/scripts/extract_block.py
@@ -10,13 +10,14 @@
 import zarr
 
 import flamingo_tools.s3_utils as s3_utils
+from flamingo_tools.file_utils import read_image_data
 
 
 def main(
     input_path: str,
     coords: List[int],
-    output_dir: str = None,
-    input_key: str = "setup0/timepoint0/s0",
+    output_dir: Optional[str] = None,
+    input_key: Optional[str] = None,
     output_key: Optional[str] = None,
     resolution: float = 0.38,
     roi_halo: List[int] = [128, 128, 64],
@@ -62,6 +63,7 @@
         basename = input_content[0] + resized_suffix
     else:
         basename = "".join(input_content[-1].split(".")[:-1])
+        image_prefix = basename.split("_")[-1]
 
     input_dir = input_path.split(basename)[0]
     input_dir = os.path.abspath(input_dir)
@@ -87,21 +89,17 @@
     roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))
 
     if s3:
-        s3_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name,
-                                           service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
+        input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name,
+                                              service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
 
-        with zarr.open(s3_path, mode="r") as f:
-            raw = f[input_key][roi]
-
-    else:
-        with zarr.open(input_path, mode="r") as f:
-            raw = f[input_key][roi]
+    data_ = read_image_data(input_path, input_key)
+    data_roi = data_[roi]
 
     if tif:
-        imageio.imwrite(output_file, raw, compression="zlib")
+        imageio.imwrite(output_file, data_roi, compression="zlib")
     else:
         with zarr.open(output_file, mode="w") as f_out:
-            f_out.create_dataset(output_key, data=raw, compression="gzip")
+            f_out.create_dataset(output_key, data=data_roi, compression="gzip")
 
 
 if __name__ == "__main__":
@@ -114,7 +112,7 @@ def main(
     parser.add_argument('-c', "--coord", type=str, required=True,
                         help="3D coordinate as center of extracted block, json-encoded.")
 
-    parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0",
+    parser.add_argument('-k', "--input_key", type=str, default=None,
                         help="Input key for data in input file.")
     parser.add_argument("--output_key", type=str, default=None,
                         help="Output key for data in output file.")
diff --git a/scripts/training/sgn_semi_supervised.py b/scripts/training/sgn_semi_supervised.py
index 3a76405..011232c 100644
--- a/scripts/training/sgn_semi_supervised.py
+++ b/scripts/training/sgn_semi_supervised.py
@@ -48,15 +48,6 @@ def run_training(name):
 
     super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()
 
-    print("super_train", len(super_train_img))
-    print("super_train", len(super_train_label))
-
-    print("super_val", len(super_val_img))
-    print("super_val", len(super_val_label))
-
-    print("unsuper",len(unsuper_train))
-    print("unsuper",len(unsuper_train))
-
     mean_teacher_training(
         name=name,
         unsupervised_train_paths=unsuper_train,
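Usage note (illustrative, not part of the patch): a minimal sketch of how the IHC postprocessing entry point introduced in this diff might be called on a segmentation table. The file path is hypothetical; the column names ('label_id', 'anchor_x/y/z', 'n_pixels') and parameter defaults follow the functions above.

import pandas as pd

from flamingo_tools.segmentation.postprocessing import postprocess_ihc_seg

# Hypothetical table path; the table must provide the columns used by the functions above:
# 'label_id', 'anchor_x', 'anchor_y', 'anchor_z' and 'n_pixels'.
table = pd.read_csv("ihc_segmentation_table.tsv", sep="\t")

# Filter small instances, connect centroids within 30 micrometer, and keep components with at least 50 nodes.
table = postprocess_ihc_seg(table, min_size=1000, min_component_length=50, min_edge_distance=30)
print(table["component_labels"].value_counts())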