diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index e7b3222..07e5913 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -249,7 +249,7 @@ def erode_subset( Returns: The dataframe containing elements left after the erosion. """ - print("initial length", len(table)) + print(f"Initial length: {len(table)}") n_neighbors = 100 for i in range(iterations): table = table[table[keyword] < threshold] @@ -406,10 +406,12 @@ def components_sgn( min_cells = 20000 threshold = threshold_erode if threshold_erode is not None else 40 - print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") - - new_subset = erode_subset(table.copy(), iterations=iterations, - threshold=threshold, min_cells=min_cells, keyword=keyword) + if iterations != 0: + print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.") + new_subset = erode_subset(table.copy(), iterations=iterations, + threshold=threshold, min_cells=min_cells, keyword=keyword) + else: + new_subset = table.copy() # create graph from coordinates of eroded subset centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) @@ -486,41 +488,10 @@ def label_components_sgn( table.sort_values("label_id") component_labels = [0 for _ in range(len(table))] + table.loc[:, "component_labels"] = component_labels # be aware of 'label_id' of dataframe starting at 1 for lab, comp in enumerate(components): - for comp_index in comp: - component_labels[comp_index - 1] = lab + 1 - - return component_labels - - -def postprocess_sgn_seg( - table: pd.DataFrame, - min_size: int = 1000, - threshold_erode: Optional[float] = None, - min_component_length: int = 50, - max_edge_distance: float = 30, - iterations_erode: Optional[int] = None, -) -> pd.DataFrame: - """Postprocessing SGN segmentation of cochlea. - - Args: - table: Dataframe of segmentation table. - min_size: Minimal number of pixels for filtering small instances. - threshold_erode: Threshold of column value after erosion step with spatial statistics. - min_component_length: Minimal length for filtering out connected components. - max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. - iterations_erode: Number of iterations for erosion, normally determined automatically. - - Returns: - Dataframe with component labels. - """ - - comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance, iterations_erode=iterations_erode) - - table.loc[:, "component_labels"] = comp_labels + table.loc[table["label_id"].isin(comp), "component_labels"] = lab + 1 return table @@ -583,37 +554,10 @@ def label_components_ihc( length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) component_labels = [0 for _ in range(len(table))] + table.loc[:, "component_labels"] = component_labels # be aware of 'label_id' of dataframe starting at 1 for lab, comp in enumerate(components): - for comp_index in comp: - component_labels[comp_index - 1] = lab + 1 - - return component_labels - - -def postprocess_ihc_seg( - table: pd.DataFrame, - min_size: int = 1000, - min_component_length: int = 50, - max_edge_distance: float = 30, -) -> pd.DataFrame: - """Postprocessing IHC segmentation of cochlea. - - Args: - table: Dataframe of segmentation table. - min_size: Minimal number of pixels for filtering small instances. - min_component_length: Minimal length for filtering out connected components. - max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. - - Returns: - Dataframe with component labels. - """ - - comp_labels = label_components_ihc(table, min_size=min_size, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance) - - table.loc[:, "component_labels"] = comp_labels + table.loc[table["label_id"].isin(comp), "component_labels"] = lab + 1 return table diff --git a/reproducibility/README.md b/reproducibility/README.md index c8f9b1c..4c87854 100644 --- a/reproducibility/README.md +++ b/reproducibility/README.md @@ -10,14 +10,14 @@ The extraction of blocks from a 3D volume is required for the creation of annota Usage: ``` python repro_block_extraction.py --input --output -``` +``` -## Post-processing of SGN segmentation +## Labeling components in the segmentation -The post-processing of the SGN segmentation may involve the erosion of the segmentation to exclude artifacts, the variation of the minimal number of nodes within a component, or the minimal distance between nodes to consider them the same component. +The labeling of the SGN segmentation may involve the erosion of the segmentation to exclude artifacts, the variation of the minimal number of nodes within a component, or the minimal distance between nodes to consider them the same component. Usage: ``` -python repro_postprocess_sgn_v1.py --input --output -``` +python repro_label_components.py --input --output +``` diff --git a/reproducibility/block_extraction/ChReef_MLR144R.json b/reproducibility/block_extraction/ChReef_MLR144R.json index 76c6716..77b8dfb 100644 --- a/reproducibility/block_extraction/ChReef_MLR144R.json +++ b/reproducibility/block_extraction/ChReef_MLR144R.json @@ -6,6 +6,9 @@ "GFP", "SGN_v2" ], + "segmentation_channel": "SGN_v2", + "type": "sgn", + "n_blocks": 6, "crop_centers": [ [ 1329, diff --git a/reproducibility/block_extraction/ChReef_MLR145R.json b/reproducibility/block_extraction/ChReef_MLR145R.json index 779701a..885d1ef 100644 --- a/reproducibility/block_extraction/ChReef_MLR145R.json +++ b/reproducibility/block_extraction/ChReef_MLR145R.json @@ -6,6 +6,9 @@ "GFP", "SGN_v2" ], + "segmentation_channel": "SGN_v2", + "type": "sgn", + "n_blocks": 6, "crop_centers": [ [ 789, diff --git a/reproducibility/block_extraction/ChReef_MLR155L.json b/reproducibility/block_extraction/ChReef_MLR155L.json new file mode 100644 index 0000000..6ef35e4 --- /dev/null +++ b/reproducibility/block_extraction/ChReef_MLR155L.json @@ -0,0 +1,53 @@ +[ + { + "cochlea": "M_LR_000155_L", + "image_channel": [ + "PV", + "GFP", + "SGN_v2" + ], + "segmentation_channel": "SGN_v2", + "type": "sgn", + "n_blocks": 6, + "crop_centers": [ + [ + 1725, + 713, + 482 + ], + [ + 1395, + 810, + 389 + ], + [ + 1070, + 681, + 454 + ], + [ + 1057, + 677, + 785 + ], + [ + 1121, + 1002, + 769 + ], + [ + 803, + 1057, + 690 + ] + ], + "halo_size": [ + 256, + 256, + 50 + ], + "component_list": [ + 1 + ] + } +] \ No newline at end of file diff --git a/reproducibility/block_extraction/ChReef_MLR155R.json b/reproducibility/block_extraction/ChReef_MLR155R.json index 38cfc9d..8d7facc 100644 --- a/reproducibility/block_extraction/ChReef_MLR155R.json +++ b/reproducibility/block_extraction/ChReef_MLR155R.json @@ -6,6 +6,9 @@ "GFP", "SGN_v2" ], + "segmentation_channel": "SGN_v2", + "type": "sgn", + "n_blocks": 6, "crop_centers": [ [ 1634, diff --git a/reproducibility/label_components/IHC_v4c_fig2.json b/reproducibility/label_components/IHC_v4c_fig2.json new file mode 100644 index 0000000..8d40000 --- /dev/null +++ b/reproducibility/label_components/IHC_v4c_fig2.json @@ -0,0 +1,26 @@ +[ + { + "cochlea": "M_LR_000226_L", + "image_channel": "VGlut3", + "cell_type": "ihc", + "unet_version": "v4c" + }, + { + "cochlea": "M_LR_000226_R", + "image_channel": "VGlut3", + "cell_type": "ihc", + "unet_version": "v4c" + }, + { + "cochlea": "M_LR_000227_L", + "image_channel": "VGlut3", + "cell_type": "ihc", + "unet_version": "v4c" + }, + { + "cochlea": "M_LR_000227_R", + "image_channel": "VGlut3", + "cell_type": "ihc", + "unet_version": "v4c" + } +] diff --git a/reproducibility/postprocess_sgn/SGN_v1_postprocess.json b/reproducibility/label_components/SGN_v1_postprocess.json similarity index 100% rename from reproducibility/postprocess_sgn/SGN_v1_postprocess.json rename to reproducibility/label_components/SGN_v1_postprocess.json diff --git a/reproducibility/label_components/SGN_v2_ChReef.json b/reproducibility/label_components/SGN_v2_ChReef.json new file mode 100644 index 0000000..9c80356 --- /dev/null +++ b/reproducibility/label_components/SGN_v2_ChReef.json @@ -0,0 +1,77 @@ +[ + { + "cochlea": "M_LR_000143_L", + "image_channel": "PV", + "cell_type": "sgn", + "max_edge_distance": 70, + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000144_L", + "image_channel": "PV", + "cell_type": "sgn", + "max_edge_distance": 50, + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000145_L", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000153_L", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000155_L", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000189_L", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000143_R", + "image_channel": "PV", + "cell_type": "sgn", + "max_edge_distance": 50, + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000144_R", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000145_R", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000153_R", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000155_R", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + }, + { + "cochlea": "M_LR_000189_R", + "image_channel": "PV", + "cell_type": "sgn", + "unet_version": "v2" + } +] diff --git a/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py b/reproducibility/label_components/repro_label_components.py similarity index 52% rename from reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py rename to reproducibility/label_components/repro_label_components.py index 852ff83..d5054a7 100644 --- a/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py +++ b/reproducibility/label_components/repro_label_components.py @@ -5,10 +5,10 @@ import pandas as pd from flamingo_tools.s3_utils import get_s3_path -from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg +from flamingo_tools.segmentation.postprocessing import label_components_sgn, label_components_ihc -def repro_postprocess_sgn_v1( +def repro_label_components( ddict: dict, output_dir: str, s3_credentials: Optional[str] = None, @@ -20,47 +20,62 @@ def repro_postprocess_sgn_v1( default_min_length = 50 default_max_edge_distance = 30 default_iterations_erode = None + default_cell_type = "sgn" + default_component_list = [1] with open(ddict, 'r') as myfile: data = myfile.read() param_dicts = json.loads(data) - for dic in param_dicts[2:4]: + for dic in param_dicts: cochlea = dic["cochlea"] - print(f"Creating components for {cochlea}.") - suffix = dic["suffix"] - tsv_path, fs = get_s3_path(dic["s3_path"], bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, 'r') as f: - table = pd.read_csv(f, sep="\t") + print(f"Labeling components for {cochlea}.") + unet_version = dic["unet_version"] threshold_erode = dic["threshold_erode"] if "threshold_erode" in dic else default_threshold_erode min_component_length = dic["min_component_length"] if "min_component_length" in dic else default_min_length max_edge_distance = dic["max_edge_distance"] if "max_edge_distance" in dic else default_max_edge_distance iterations_erode = dic["iterations_erode"] if "iterations_erode" in dic else default_iterations_erode + cell_type = dic["cell_type"] if "cell_type" in dic else default_cell_type + component_list = dic["component_list"] if "component_list" in dic else default_component_list - print("threshold_erode", threshold_erode) - print("min_component_length", min_component_length) - print("max_edge", max_edge_distance) - print("iterations_erode", iterations_erode) - - tsv_table = postprocess_sgn_seg(table, min_size=min_size, - threshold_erode=threshold_erode, - min_component_length=min_component_length, - max_edge_distance=max_edge_distance, - iterations_erode=iterations_erode) - - largest_comp = len(tsv_table[tsv_table["component_labels"] == 1]) - print(f"Largest component has {largest_comp} SGNs.") + table_name = f"{cell_type.upper()}_{unet_version}" + s3_path = os.path.join(f"{cochlea}", "tables", table_name, "default.tsv") + tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + with fs.open(tsv_path, 'r') as f: + table = pd.read_csv(f, sep="\t") - out_path = os.path.join(output_dir, "".join([cochlea, suffix, ".tsv"])) + if cell_type == "sgn": + tsv_table = label_components_sgn(table, min_size=min_size, + threshold_erode=threshold_erode, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance, + iterations_erode=iterations_erode) + elif cell_type == "ihc": + tsv_table = label_components_ihc(table, min_size=min_size, + min_component_length=min_component_length, + max_edge_distance=max_edge_distance) + else: + raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.") + + largest_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)]) + print(f"The segmentation features {len(tsv_table)} {cell_type.upper()}s.") + if component_list == [1]: + print(f"Largest component has {largest_comp} {cell_type.upper()}s.") + else: + print(f"Custom component(s) have {largest_comp} {cell_type.upper()}s.") + + cochlea_str = "-".join(cochlea.split("_")) + table_str = "-".join(table_name.split("_")) + out_path = os.path.join(output_dir, "_".join([cochlea_str, f"{table_str}.tsv"])) tsv_table.to_csv(out_path, sep="\t", index=False) def main(): parser = argparse.ArgumentParser( - description="Script to extract region of interest (ROI) block around center coordinate.") + description="Script to label segmentation using a segmentation table and graph connected components.") parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.") parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") @@ -75,7 +90,7 @@ def main(): args = parser.parse_args() - repro_postprocess_sgn_v1( + repro_label_components( args.input, args.output, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, ) diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 0cfa7fc..74a255f 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -7,7 +7,7 @@ import flamingo_tools.s3_utils as s3_utils from flamingo_tools.segmentation import filter_segmentation from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius -from flamingo_tools.segmentation.postprocessing import postprocess_sgn_seg +from flamingo_tools.segmentation.postprocessing import label_components_sgn # TODO needs updates @@ -124,7 +124,7 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold): tsv_table = pd.read_csv(f, sep="\t") if seg_path is None: - post_table = postprocess_sgn_seg( + post_table = label_components_sgn( tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold, min_component_length=args.min_component_length, max_edge_distance=args.max_edge_dist, iterations_erode=args.iterations_erode,