diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py index dd03c57..8ea668f 100644 --- a/flamingo_tools/measurements.py +++ b/flamingo_tools/measurements.py @@ -2,12 +2,16 @@ import os from concurrent import futures from functools import partial -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np import pandas as pd import trimesh +from elf.io import open_file +from elf.wrapper.resized_volume import ResizedVolume +from nifty.tools import blocking from skimage.measure import marching_cubes, regionprops_table +from scipy.ndimage import binary_dilation from tqdm import tqdm from .file_utils import read_image_data @@ -29,9 +33,14 @@ def _measure_volume_and_surface(mask, resolution): return volume, surface -def _get_bounding_box_and_center(table, seg_id, resolution, shape): +def _get_bounding_box_and_center(table, seg_id, resolution, shape, dilation): row = table[table.label_id == seg_id] + if dilation is not None and dilation > 0: + bb_extension = dilation + 1 + else: + bb_extension = 1 + bb_min = np.array([ row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item() ]).astype("float32") / resolution @@ -43,7 +52,7 @@ def _get_bounding_box_and_center(table, seg_id, resolution, shape): bb_max = np.round(bb_max, 0).astype("int32") bb = tuple( - slice(max(bmin - 1, 0), min(bmax + 1, sh)) + slice(max(bmin - bb_extension, 0), min(bmax + bb_extension, sh)) for bmin, bmax, sh in zip(bb_min, bb_max, shape) ) @@ -115,13 +124,15 @@ def _normalize_background(measures, image, mask, center, radius, norm, median_on def _default_object_features( seg_id, table, image, segmentation, resolution, - foreground_mask=None, background_radius=None, norm=np.divide, median_only=False, + background_mask=None, background_radius=None, norm=np.divide, median_only=False, dilation=None ): - bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape) + bb, center = _get_bounding_box_and_center(table, seg_id, 
resolution, image.shape, dilation) local_image = image[bb] mask = segmentation[bb] == seg_id assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty." + if dilation is not None and dilation > 0: + mask = binary_dilation(mask, iterations=dilation) masked_intensity = local_image[mask] # Do the base intensity measurements. @@ -141,7 +152,7 @@ def _default_object_features( # The resolution is given in micrometer per pixel. # So we have to divide by the resolution to obtain the radius in pixel. radius_in_pixel = background_radius / resolution - measures = _normalize_background(measures, image, foreground_mask, center, radius_in_pixel, norm, median_only) + measures = _normalize_background(measures, image, background_mask, center, radius_in_pixel, norm, median_only) # Do the volume and surface measurement. if not median_only: @@ -151,13 +162,15 @@ def _default_object_features( return measures -def _regionprops_features(seg_id, table, image, segmentation, resolution, foreground_mask=None): - bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape) +def _regionprops_features(seg_id, table, image, segmentation, resolution, background_mask=None, dilation=None): + bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape, dilation) local_image = image[bb] local_segmentation = segmentation[bb] mask = local_segmentation == seg_id assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty." 
+ if dilation is not None and dilation > 0: + mask = binary_dilation(mask, iterations=dilation) local_segmentation[~mask] = 0 features = regionprops_table( @@ -196,7 +209,6 @@ def _regionprops_features(seg_id, table, image, segmentation, resolution, foregr """ -# TODO integrate segmentation post-processing, see `_extend_sgns_simple` in `gfp_annotation.py` def compute_object_measures_impl( image: np.typing.ArrayLike, segmentation: np.typing.ArrayLike, @@ -204,8 +216,9 @@ def compute_object_measures_impl( resolution: float = 0.38, table: Optional[pd.DataFrame] = None, feature_set: str = "default", - foreground_mask: Optional[np.typing.ArrayLike] = None, + background_mask: Optional[np.typing.ArrayLike] = None, median_only: bool = False, + dilation: Optional[int] = None, ) -> pd.DataFrame: """Compute simple intensity and morphology measures for each segmented cell in a segmentation. @@ -218,8 +231,10 @@ def compute_object_measures_impl( resolution: The resolution / voxel size of the data. table: The segmentation table. Will be computed on the fly if it is not given. feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details. - foreground_mask: An optional mask indicating the area to use for computing background correction values. + background_mask: An optional mask indicating the area to use for computing background correction values. median_only: Whether to only compute the median intensity. + dilation: Value for dilating the segmentation before computing measurements. + By default no dilation is applied. Returns: The table with per object measurements. @@ -235,8 +250,9 @@ def compute_object_measures_impl( image=image, segmentation=segmentation, resolution=resolution, - foreground_mask=foreground_mask, + background_mask=background_mask, median_only=median_only, + dilation=dilation, ) seg_ids = table.label_id.values @@ -246,6 +262,7 @@ def compute_object_measures_impl( # For debugging. 
# measure_function(seg_ids[0]) + # breakpoint() with futures.ThreadPoolExecutor(n_threads) as pool: measures = list(tqdm( @@ -272,6 +289,9 @@ def compute_object_measures( feature_set: str = "default", s3_flag: bool = False, component_list: List[int] = [], + dilation: Optional[int] = None, + median_only: bool = False, + background_mask: Optional[np.typing.ArrayLike] = None, ) -> None: """Compute simple intensity and morphology measures for each segmented cell in a segmentation. @@ -291,6 +311,12 @@ def compute_object_measures( resolution: The resolution / voxel size of the data. force: Whether to overwrite an existing output table. feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details. + s3_flag: + component_list: + median_only: Whether to only compute the median intensity. + dilation: Value for dilating the segmentation before computing measurements. + By default no dilation is applied. + background_mask: An optional mask indicating the area to use for computing background correction values. """ if os.path.exists(output_table_path) and not force: return @@ -315,5 +341,92 @@ def compute_object_measures( measures = compute_object_measures_impl( image, segmentation, n_threads, resolution, table=table, feature_set=feature_set, + median_only=median_only, dilation=dilation, background_mask=background_mask, ) measures.to_csv(output_table_path, sep="\t", index=False) + + +def compute_sgn_background_mask( + image_path: str, + segmentation_path: str, + image_key: Optional[str] = None, + segmentation_key: Optional[str] = None, + threshold_percentile: float = 35.0, + scale_factor: Tuple[int, int, int] = (16, 16, 16), + n_threads: Optional[int] = None, + cache_path: Optional[str] = None, +) -> np.typing.ArrayLike: + """Compute the background mask for intensity measurements in the SGN segmentation. + + This function computes a mask for determining the background signal in the rosenthal canal. 
+ It is computed by downsampling the image (PV) and segmentation (SGNs) internally, + by thresholding the downsampled image, and by then excluding the (dilated) segmentation from this mask. + This results in a mask that is positive for the background signal within the Rosenthal's canal. + + Args: + image_path: The path to the image data with the PV channel. + segmentation_path: The path to the SGN segmentation. + image_key: Internal path for the image data, for zarr or similar file formats. + segmentation_key: Internal path for the segmentation data, for zarr or similar file formats. + threshold_percentile: The percentile threshold for separating foreground and background in the PV signal. + scale_factor: The scale factor for internally downsampling the mask. + n_threads: The number of threads for parallelizing the computation. + cache_path: Optional path to save the downscaled background mask to zarr. + + Returns: + The mask for determining the background values. + """ + image = read_image_data(image_path, image_key) + segmentation = read_image_data(segmentation_path, segmentation_key) + assert image.shape == segmentation.shape + + if cache_path is not None and os.path.exists(cache_path): + with open_file(cache_path, "r") as f: + if "mask" in f: + low_res_mask = f["mask"][:] + mask = ResizedVolume(low_res_mask, shape=image.shape, order=0) + return mask + + original_shape = image.shape + downsampled_shape = tuple(int(np.round(sh / sf)) for sh, sf in zip(original_shape, scale_factor)) + + low_res_mask = np.zeros(downsampled_shape, dtype="bool") + + # This corresponds to a block shape of 128 x 512 x 512 in the original resolution, + # which roughly corresponds to the size of the blocks we use for the GFP annotation. 
+ chunk_shape = (8, 32, 32) + + blocks = blocking((0, 0, 0), downsampled_shape, chunk_shape) + n_blocks = blocks.numberOfBlocks + + img_resized = ResizedVolume(image, downsampled_shape) + seg_resized = ResizedVolume(segmentation, downsampled_shape, order=0) + + def _compute_block(block_id): + block = blocks.getBlock(block_id) + bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) + + img = img_resized[bb] + threshold = np.percentile(img, threshold_percentile) + + this_mask = img > threshold + this_seg = seg_resized[bb] != 0 + this_seg = binary_dilation(this_seg) + this_mask[this_seg] = 0 + + low_res_mask[bb] = this_mask + + n_threads = mp.cpu_count() if n_threads is None else n_threads + randomized_blocks = np.arange(0, n_blocks) + np.random.shuffle(randomized_blocks) + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm( + tp.map(_compute_block, randomized_blocks), total=n_blocks, desc="Compute background mask" + )) + + if cache_path is not None: + with open_file(cache_path, "a") as f: + f.create_dataset("mask", data=low_res_mask, chunks=(64, 64, 64)) + + mask = ResizedVolume(low_res_mask, shape=original_shape, order=0) + return mask diff --git a/flamingo_tools/segmentation/chreef_utils.py b/flamingo_tools/segmentation/chreef_utils.py new file mode 100644 index 0000000..26fb8e4 --- /dev/null +++ b/flamingo_tools/segmentation/chreef_utils.py @@ -0,0 +1,164 @@ +import os +import multiprocessing as mp +from concurrent import futures +from typing import List, Tuple + +import numpy as np +import tifffile +from tqdm import tqdm + + +def coord_from_string(center_str): + return tuple([int(c) for c in center_str.split("-")]) + + +def find_annotations(annotation_dir, cochlea) -> dict: + """Create dictionary for analysis of ChReef annotations. + Annotations should have format positive-negative__crop__allNegativeExcluded_thr.tif + + Args: + annotation_dir: Directory containing annotations. 
+ """ + + def extract_center_string(cochlea, name): + # Extract center crop coordinate from file name + crop_suffix = name.split(f"{cochlea}_crop_")[1] + center_str = crop_suffix.split("_")[0] + return center_str + + cochlea_files = [entry.name for entry in os.scandir(annotation_dir) if cochlea in entry.name] + dic = {"cochlea": cochlea} + dic["cochlea_files"] = cochlea_files + center_strings = list(set([extract_center_string(cochlea, name=f) for f in cochlea_files])) + center_strings.sort() + dic["center_strings"] = center_strings + remove_strings = [] + for center_str in center_strings: + files_neg = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "NegativeExcluded"])] + files_pos = [c for c in cochlea_files if all(x in c for x in [cochlea, center_str, "WeakPositive"])] + if len(files_neg) != 1 or len(files_pos) != 1: + print(f"Skipping crop {center_str} for cochlea {cochlea}. " + f"Missing or multiple annotation files in {annotation_dir}.") + remove_strings.append(center_str) + else: + dic[center_str] = {"file_neg": os.path.join(annotation_dir, files_neg[0]), + "file_pos": os.path.join(annotation_dir, files_pos[0])} + for rm_str in remove_strings: + dic["center_strings"].remove(rm_str) + + return dic + + +def get_roi(coord: tuple, roi_halo: tuple, resolution: float = 0.38) -> Tuple[int]: + """Get parameters for loading ROI of segmentation. + + Args: + coord: Center coordinate. + roi_halo: Halo for roi. + resolution: Resolution of array in µm. 
+ + Returns: + region of interest + """ + coords = list(coord) + # reverse dimensions for correct extraction + coords.reverse() + coords = np.array(coords) + coords = coords / resolution + coords = np.round(coords).astype(np.int32) + + roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) + return roi + + +def find_overlapping_masks( + arr_base: np.ndarray, + arr_ref: np.ndarray, + label_id_base: int, + min_overlap: float = 0.5, +) -> List[int]: + """Find overlapping masks between base array and reference array. + + Args: + arr_base: Base array. + arr_ref: Reference array. + label_id_base: ID of segmentation to check for overlap. + min_overlap: Minimal overlap to consider segmentation ID as matching. + + Returns: + Matching IDs of reference array. + """ + arr_base_labeled = arr_base == label_id_base + + # iterate through segmentation ids in reference mask + ref_ids = list(np.unique(arr_ref)[1:]) + + def check_overlap(ref_id): + # check overlap of reference ID and base + arr_ref_instance = arr_ref == ref_id + + intersection = np.logical_and(arr_ref_instance, arr_base_labeled) + overlap_ratio = np.sum(intersection) / np.sum(arr_ref_instance) + if overlap_ratio >= min_overlap: + return ref_id + else: + return None + + n_threads = min(16, mp.cpu_count()) + print(f"Finding overlapping masks with {n_threads} Threads.") + with futures.ThreadPoolExecutor(n_threads) as pool: + results = list(tqdm(pool.map(check_overlap, ref_ids), total=len(ref_ids))) + + matching_ids = {r for r in results if r is not None} + return matching_ids + + +def find_inbetween_ids( + arr_negexc: np.typing.ArrayLike, + arr_allweak: np.typing.ArrayLike, + roi_seg: np.typing.ArrayLike, +) -> List[int]: + """Identify list of segmentation IDs inbetween thresholds. + + Args: + arr_negexc: Array with all negatives excluded. + arr_allweak: Array with all weak positives. + roi_seg: Region of interest of segmentation. 
+ """ + # negative annotation == 1, positive annotation == 2 + negexc_negatives = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=1) + allweak_positives = find_overlapping_masks(arr_allweak, roi_seg, label_id_base=2) + inbetween_ids = [int(i) for i in set(negexc_negatives).intersection(set(allweak_positives))] + return inbetween_ids + + +def get_median_intensity(file_negexc, file_allweak, center, data_seg, table): + arr_negexc = tifffile.imread(file_negexc) + arr_allweak = tifffile.imread(file_allweak) + + roi_halo = tuple([r // 2 for r in arr_negexc.shape]) + roi = get_roi(center, roi_halo) + + roi_seg = data_seg[roi] + inbetween_ids = find_inbetween_ids(arr_negexc, arr_allweak, roi_seg) + subset = table[table["label_id"].isin(inbetween_ids)] + intensities = list(subset["median"]) + return np.median(list(intensities)) + + +def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure): + """Find median intensities in blocks and assign them to center positions of cropped block. 
+ """ + annotation_dic = find_annotations(annotation_dir, cochlea) + # center_keys = [key for key in annotation_dic["center_strings"] if key in annotation_dic.keys()] + + for center_str in annotation_dic["center_strings"]: + center_coord = coord_from_string(center_str) + print(f"Getting mean intensities for {center_coord}.") + file_pos = annotation_dic[center_str]["file_pos"] + file_neg = annotation_dic[center_str]["file_neg"] + median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg, table_measure) + + annotation_dic[center_str]["median_intensity"] = median_intensity + + return annotation_dic diff --git a/reproducibility/object_measures/2025-07-SGN_PV-GFP.json b/reproducibility/object_measures/2025-07-SGN_PV-GFP.json new file mode 100644 index 0000000..b203af8 --- /dev/null +++ b/reproducibility/object_measures/2025-07-SGN_PV-GFP.json @@ -0,0 +1,148 @@ +[ + { + "cochlea": "M_LR_000143_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000144_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000145_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000153_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1, + 2, + 3 + ] + }, + { + "cochlea": "M_LR_000155_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000189_L", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000143_R", + 
"image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000144_R", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000145_R", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000153_R", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000155_R", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + }, + { + "cochlea": "M_LR_000189_R", + "image_channel": [ + "PV", + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR143L.json b/reproducibility/object_measures/ChReef_MLR143L.json new file mode 100644 index 0000000..aba84d0 --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR143L.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000143_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR143R.json b/reproducibility/object_measures/ChReef_MLR143R.json new file mode 100644 index 0000000..96892a9 --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR143R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000143_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] \ No newline at end of file diff --git a/reproducibility/object_measures/ChReef_MLR144L.json 
b/reproducibility/object_measures/ChReef_MLR144L.json new file mode 100644 index 0000000..165f6ea --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR144L.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000144_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR144R.json b/reproducibility/object_measures/ChReef_MLR144R.json new file mode 100644 index 0000000..2d98feb --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR144R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000144_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR145L.json b/reproducibility/object_measures/ChReef_MLR145L.json new file mode 100644 index 0000000..f4099d3 --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR145L.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000145_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR145R.json b/reproducibility/object_measures/ChReef_MLR145R.json new file mode 100644 index 0000000..36b9398 --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR145R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000145_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR153L.json b/reproducibility/object_measures/ChReef_MLR153L.json new file mode 100644 index 0000000..09ae6cd --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR153L.json @@ -0,0 +1,15 @@ +[ + { + "cochlea": "M_LR_000153_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + 
"background_mask": "yes", + "component_list": [ + 1, + 2, + 3 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR153R.json b/reproducibility/object_measures/ChReef_MLR153R.json new file mode 100644 index 0000000..98a5bbb --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR153R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000153_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR155L.json b/reproducibility/object_measures/ChReef_MLR155L.json new file mode 100644 index 0000000..5dbeb24 --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR155L.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000155_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR155R.json b/reproducibility/object_measures/ChReef_MLR155R.json new file mode 100644 index 0000000..c71c1ad --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR155R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000155_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR189L.json b/reproducibility/object_measures/ChReef_MLR189L.json new file mode 100644 index 0000000..702a9ab --- /dev/null +++ b/reproducibility/object_measures/ChReef_MLR189L.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000189_L", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/ChReef_MLR189R.json b/reproducibility/object_measures/ChReef_MLR189R.json new file mode 100644 index 0000000..471215b --- /dev/null +++ 
b/reproducibility/object_measures/ChReef_MLR189R.json @@ -0,0 +1,13 @@ +[ + { + "cochlea": "M_LR_000189_R", + "image_channel": [ + "GFP" + ], + "segmentation_channel": "SGN_v2", + "background_mask": "yes", + "component_list": [ + 1 + ] + } +] diff --git a/reproducibility/object_measures/process_all_object_measures.py b/reproducibility/object_measures/process_all_object_measures.py new file mode 100644 index 0000000..ff05b37 --- /dev/null +++ b/reproducibility/object_measures/process_all_object_measures.py @@ -0,0 +1,74 @@ +import json +import os +import subprocess +import zarr + +import flamingo_tools.s3_utils as s3_utils + +OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements/" # noqa +JSON_ROOT = "/user/pape41/u12086/Work/my_projects/flamingo-tools/reproducibility/object_measures" +COCHLEAE = [ + "M_LR_000143_L", + "M_LR_000144_L", + "M_LR_000145_L", + "M_LR_000153_L", + "M_LR_000155_L", + "M_LR_000189_L", + "M_LR_000143_R", + "M_LR_000144_R", + "M_LR_000145_R", + "M_LR_000153_R", + "M_LR_000155_R", + "M_LR_000189_R", +] + + +def process_cochlea(cochlea, start_slurm): + short_name = cochlea.replace("_", "").replace("0", "") + + # Check if this cochlea has been processed already. + output_name = cochlea.replace("_", "-") + output_path = os.path.join(OUTPUT_ROOT, f"{output_name}_GFP_SGN-v2_object-measures.tsv") + if os.path.exists(output_path): + print(cochlea, "has been processed already.") + return + + # Check if the raw data for this cochlea is accessible. + img_name = f"{cochlea}/images/ome-zarr/GFP.ome.zarr" + img_path, _ = s3_utils.get_s3_path(img_name) + try: + zarr.open(img_path, mode="r") + except Exception: + print("The data for", cochlea, "at", img_name, "does not exist.") + return + + # Then generate the json file if it does not yet exist. 
+ template_path = os.path.join(JSON_ROOT, "ChReef_MLR143L.json") + with open(template_path, "r") as f: + json_template = json.load(f) + + json_path = os.path.join(JSON_ROOT, f"ChReef_{short_name}.json") + if not os.path.exists(json_path): + print("Write json to", json_path) + # TODO: We may need to replace the component list for some. + json_template[0]["cochlea"] = cochlea + with open(json_path, "w") as f: + json.dump(json_template, f, indent=4) + + print(cochlea, "is not yet processed") + # Then start the slurm job. + if not start_slurm: + return + + print("Submit slurm job for", cochlea) + subprocess.run(["sbatch", "slurm_template.sbatch", json_path, OUTPUT_ROOT]) + + +def main(): + start_slurm = False + for cochlea in COCHLEAE: + process_cochlea(cochlea, start_slurm) + + +if __name__ == "__main__": + main() diff --git a/reproducibility/object_measures/repro_object_measures.py b/reproducibility/object_measures/repro_object_measures.py index a1b2d49..51c3559 100644 --- a/reproducibility/object_measures/repro_object_measures.py +++ b/reproducibility/object_measures/repro_object_measures.py @@ -1,21 +1,25 @@ import argparse import json import os +from multiprocessing import cpu_count from typing import Optional import flamingo_tools.s3_utils as s3_utils -from flamingo_tools.measurements import compute_object_measures +from flamingo_tools.measurements import compute_object_measures, compute_sgn_background_mask def repro_object_measures( json_file: str, output_dir: str, + force_overwrite: bool = False, s3_credentials: Optional[str] = None, s3_bucket_name: Optional[str] = None, s3_service_endpoint: Optional[str] = None, ): s3_flag = True input_key = "s0" + default_component_list = [1] + default_bg_mask = None with open(json_file, 'r') as myfile: data = myfile.read() @@ -25,7 +29,8 @@ def repro_object_measures( cochlea = dic["cochlea"] image_channels = dic["image_channel"] if isinstance(dic["image_channel"], list) else [dic["image_channel"]] seg_channel = 
dic["segmentation_channel"] - component_list = dic["component_list"] + component_list = dic["component_list"] if "component_list" in dic else default_component_list + bg_mask = dic["background_mask"] if "background_mask" in dic else default_bg_mask print(f"Processing cochlea {cochlea}") for img_channel in image_channels: @@ -45,15 +50,45 @@ def repro_object_measures( seg_path, fs = s3_utils.get_s3_path(seg_s3, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - compute_object_measures( - image_path=img_path, - segmentation_path=seg_path, - segmentation_table_path=seg_table_s3, - output_table_path=output_table_path, - image_key=input_key, - segmentation_key=input_key, - s3_flag=s3_flag, - component_list=component_list) + n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count())) + if os.path.isfile(output_table_path) and not force_overwrite: + print(f"Skipping creation of {output_table_path}. File already exists.") + + else: + if bg_mask is None: + feature_set = "default" + dilation = None + median_only = False + else: + print("Using background mask for calculating object measures.") + feature_set = "default_background_subtract" + dilation = 4 + median_only = True + mask_cache_path = os.path.join(output_dir, f"{cochlea_str}_{img_str}_{seg_str}_bg-mask.zarr") + bg_mask = compute_sgn_background_mask( + image_path=img_path, + segmentation_path=seg_path, + image_key=input_key, + segmentation_key=input_key, + n_threads=n_threads, + cache_path=mask_cache_path, + ) + + compute_object_measures( + image_path=img_path, + segmentation_path=seg_path, + segmentation_table_path=seg_table_s3, + output_table_path=output_table_path, + image_key=input_key, + segmentation_key=input_key, + feature_set=feature_set, + s3_flag=s3_flag, + component_list=component_list, + dilation=dilation, + median_only=median_only, + background_mask=bg_mask, + n_threads=n_threads, + ) def main(): @@ -63,6 +98,8 @@ def main(): 
parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.") parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") + parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.") + parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. " "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") diff --git a/reproducibility/object_measures/slurm_template.sbatch b/reproducibility/object_measures/slurm_template.sbatch new file mode 100755 index 0000000..7c67ac0 --- /dev/null +++ b/reproducibility/object_measures/slurm_template.sbatch @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH -t 24:00:00 # estimated time, adapt to your needs +#SBATCH --mail-user=constantin.pape@informatik.uni-goettingen.de # change this to your mail address +#SBATCH --mail-type=FAIL # send mail when the job fails + +#SBATCH -p standard96s:shared # the partition +#SBATCH -A nim00007 + +#SBATCH -c 32 +#SBATCH --mem 256G + +PYTHON=/scratch-grete/usr/nimcpape/software/mamba/envs/sam/bin/python +SCRIPT=/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/reproducibility/object_measures/repro_object_measures.py + +$PYTHON $SCRIPT --input $1 --output $2 diff --git a/scripts/intensity_annotation/gfp_annotation.py b/scripts/intensity_annotation/gfp_annotation.py index fa29349..b11eba8 100644 --- a/scripts/intensity_annotation/gfp_annotation.py +++ b/scripts/intensity_annotation/gfp_annotation.py @@ -180,7 +180,7 @@ def gfp_annotation(prefix, default_stat="median", background_norm=None, is_otof= assert mask.shape == seg_extended.shape feature_set = "default_background_norm" if background_norm == "division" else "default_background_subtract" statistics = compute_object_measures_impl( - stain1, seg_extended, feature_set=feature_set, foreground_mask=mask, median_only=True + stain1, seg_extended, feature_set=feature_set, background_mask=mask, 
"""Assign a marker label to every segmentation instance of a cochlea.

Source file: scripts/measurements/evaluate_marker_annotations.py

Intensity thresholds obtained from annotated crops are mapped onto the
'length_fraction' column of the segmentation table. Each instance is then labeled
as positive (1) or negative (2) by comparing its median intensity from the object
measurement table against the threshold of the crop that is nearest along
'length_fraction'.
"""
import argparse
import os
from typing import List, Optional

import pandas as pd

from flamingo_tools.s3_utils import get_s3_path
from flamingo_tools.file_utils import read_image_data
from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations

MARKER_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ChReef_PV-GFP/2025-07_PV_GFP_SGN"


def get_length_fraction_from_center(table, center_str, offset=20):
    """Get the 'length_fraction' of a center coordinate by averaging nearby segmentation instances.

    Args:
        table: Segmentation table with the columns 'anchor_x', 'anchor_y', 'anchor_z' and 'length_fraction'.
        center_str: Center coordinate given as three integers joined by '-'.
            The values are compared against 'anchor_x', 'anchor_y' and 'anchor_z', in this order.
        offset: Half side length of the cube around the center in which instances are averaged.
            Both bounds are exclusive.

    Returns:
        The mean 'length_fraction' of all instances inside the cube.

    Raises:
        ValueError: If 'center_str' does not consist of three integers,
            or if no segmentation instance lies inside the cube.
    """
    cx, cy, cz = (int(c) for c in center_str.split("-"))

    in_cube = pd.Series(True, index=table.index)
    for column, center in (("anchor_x", cx), ("anchor_y", cy), ("anchor_z", cz)):
        in_cube &= (center - offset < table[column]) & (table[column] < center + offset)

    values = list(table.loc[in_cube, "length_fraction"])
    if not values:
        # Without this check the averaging below fails with an uninformative ZeroDivisionError.
        raise ValueError(
            f"No segmentation instances found within {offset} of center '{center_str}'; "
            "cannot determine its length fraction."
        )
    return float(sum(values) / len(values))


def apply_nearest_threshold(intensity_dic, table_seg, table_measurement):
    """Apply threshold to nearest segmentation instances.

    Crop centers are transformed into the 'length fraction' parameter of the segmentation table.
    This avoids issues with the spiral shape of the cochlea and maps the assignment onto the Rosenthal's canal.

    Args:
        intensity_dic: Dictionary mapping crop center strings to a dictionary with the key 'median_intensity'.
            The key 'length_fraction' is added to each entry in place.
        table_seg: Segmentation table with the columns 'label_id', 'length_fraction'
            and the anchor columns. The column 'marker_labels' is added in place.
        table_measurement: Measurement table with the columns 'label_id' and 'median'.

    Returns:
        The segmentation table with the column 'marker_labels':
        1 for positive, 2 for negative and 0 for instances that were not assigned.
    """
    # Assign crop centers to the length fraction of Rosenthal's canal.
    thresholds = {}
    for center_str, entry in intensity_dic.items():
        length_fraction = get_length_fraction_from_center(table_seg, center_str)
        entry["length_fraction"] = length_fraction
        thresholds[length_fraction] = entry["median_intensity"]

    # The limits of each bin are the start of the cochlea (0), the half distances
    # between neighboring crop centers, and the end of the cochlea (1).
    lf_fractions = sorted(thresholds)
    midpoints = [(lower + upper) / 2 for lower, upper in zip(lf_fractions, lf_fractions[1:])]
    lf_limits = [0] + midpoints + [1]

    table_seg["marker_labels"] = 0
    seg_fractions = table_seg["length_fraction"]
    for fraction, lower, upper in zip(lf_fractions, lf_limits, lf_limits[1:]):
        # Both bounds are exclusive, so instances whose length fraction is exactly
        # equal to a limit (including 0 and 1) keep the label 0.
        in_bin = (seg_fractions > lower) & (seg_fractions < upper)
        label_ids_seg = table_seg.loc[in_bin, "label_id"]

        threshold = thresholds[fraction]
        measured = table_measurement[table_measurement["label_id"].isin(label_ids_seg)]
        # The two comparisons are kept explicit: instances with a NaN median match neither and stay 0.
        label_ids_pos = list(measured.loc[measured["median"] >= threshold, "label_id"])
        label_ids_neg = list(measured.loc[measured["median"] < threshold, "label_id"])

        table_seg.loc[table_seg["label_id"].isin(label_ids_pos), "marker_labels"] = 1
        table_seg.loc[table_seg["label_id"].isin(label_ids_neg), "marker_labels"] = 2

    return table_seg


def evaluate_marker_annotation(
    cochleae: List[str],
    output_dir: str,
    annotation_dirs: Optional[List[str]] = None,
    seg_name: str = "SGN_v2",
    marker_name: str = "GFP",
):
    """Evaluate marker annotations of a single or multiple annotators.

    Segmentation instances are assigned a positive (1) or negative label (2)
    in form of the "marker_labels" column of the output segmentation table.
    The assignment is based on the median intensity supplied by a measurement table.
    Instances not considered for the assignment are labeled as 0.

    Args:
        cochleae: List of cochlea names.
        output_dir: Output directory for the segmentation tables with the 'marker_labels' column.
            One file named <cochlea>_<marker_name>_<seg_name>.tsv is written per cochlea,
            where '_' in the cochlea and segmentation name is replaced by '-'.
            The directory is created if it does not exist.
        annotation_dirs: List of directories containing marker annotations by annotator(s).
            By default all sub-directories of MARKER_DIR with 'Results' in their name are used.
        seg_name: Identifier for segmentation.
        marker_name: Identifier for marker stain.
    """
    input_key = "s0"

    if annotation_dirs is None:
        # Use the context manager so that the directory iterator is closed deterministically.
        with os.scandir(MARKER_DIR) as entries:
            annotation_dirs = [entry.path for entry in entries if entry.is_dir() and "Results" in entry.name]

    os.makedirs(output_dir, exist_ok=True)
    seg_string = seg_name.replace("_", "-")

    for cochlea in cochleae:
        cochlea_annotations = [a for a in annotation_dirs if len(find_annotations(a, cochlea)["center_strings"]) != 0]
        print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.")

        # Get the segmentation data.
        input_path, _ = get_s3_path(f"{cochlea}/images/ome-zarr/{seg_name}.ome.zarr")
        data_seg = read_image_data(input_path, input_key)

        table_path_s3, fs = get_s3_path(f"{cochlea}/tables/{seg_name}/default.tsv")
        with fs.open(table_path_s3, "r") as f:
            table_seg = pd.read_csv(f, sep="\t")

        table_measurement_path = f"{cochlea}/tables/{seg_name}/{marker_name}_{seg_string}_object-measures.tsv"
        table_path_s3, fs = get_s3_path(table_measurement_path)
        with fs.open(table_path_s3, "r") as f:
            table_measurement = pd.read_csv(f, sep="\t")

        # Collect the annotation results of every annotator for this cochlea.
        annotation_dics = {}
        annotated_centers = set()
        for annotation_dir in cochlea_annotations:
            annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg, table_measurement)
            annotated_centers.update(annotation_dic["center_strings"])
            annotation_dics[annotation_dir] = annotation_dic

        # Find the median intensity per crop by averaging over all annotators that annotated it.
        # The centers are sorted so that the processing order is deterministic.
        intensity_dic = {}
        for annotated_center in sorted(annotated_centers):
            intensities = [
                annotation_dic[annotated_center]["median_intensity"]
                for annotation_dic in annotation_dics.values()
                if annotated_center in annotation_dic["center_strings"]
            ]
            intensity_dic[annotated_center] = {"median_intensity": float(sum(intensities) / len(intensities))}

        table_seg = apply_nearest_threshold(intensity_dic, table_seg, table_measurement)
        cochlea_str = cochlea.replace("_", "-")
        out_path = os.path.join(output_dir, f"{cochlea_str}_{marker_name}_{seg_string}.tsv")
        table_seg.to_csv(out_path, sep="\t", index=False)


def main():
    """Parse the command line arguments and run the marker evaluation."""
    parser = argparse.ArgumentParser(
        description="Assign each segmentation instance a marker based on annotation thresholds.")

    parser.add_argument('-c', "--cochlea", type=str, nargs="+", required=True,
                        help="Cochlea(e) to process.")
    parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.")

    parser.add_argument('-a', '--annotation_dirs', type=str, nargs="+", default=None,
                        help="Directories containing marker annotations.")
    parser.add_argument("--seg_name", type=str, default="SGN_v2", help="Identifier for segmentation.")
    parser.add_argument("--marker_name", type=str, default="GFP", help="Identifier for marker stain.")

    args = parser.parse_args()

    evaluate_marker_annotation(
        args.cochlea, args.output, args.annotation_dirs,
        seg_name=args.seg_name, marker_name=args.marker_name,
    )


if __name__ == "__main__":

    main()
# Source file: test/test_measurements.py
#
# Test the object measurement functionality as it's used for the gfp intensity measurements:
# - computing only median intensity
# - with a dilation of 4
# - with background subtraction
# - and using a mask for the background subtraction
def test_compute_object_measures_gfp(self):
    """Run the median-only, background-subtracted measurement and check the resulting table."""
    from flamingo_tools.measurements import compute_object_measures, compute_sgn_background_mask

    result_file = os.path.join(self.folder, "measurements.tsv")
    mask_for_background = compute_sgn_background_mask(
        self.image_path, self.seg_path, scale_factor=(2, 4, 4)
    )

    compute_object_measures(
        self.image_path, self.seg_path, self.table_path, result_file,
        n_threads=1,
        feature_set="default_background_subtract",
        median_only=True,
        dilation=4,
        background_mask=mask_for_background,
    )
    self.assertTrue(os.path.exists(result_file))

    # The table must contain at least one object and the columns computed in median-only mode.
    measurements = pd.read_csv(result_file, sep="\t")
    self.assertGreaterEqual(len(measurements), 1)
    for column_name in ("label_id", "median"):
        self.assertIn(column_name, measurements.columns)


if __name__ == "__main__":
    unittest.main()