From 66670cd237b3541f549f43c76699e3783007ee7a Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 14 Apr 2025 22:35:27 +0200 Subject: [PATCH 1/7] Refactor code, apply linter, and unify style of docstring --- flamingo_tools/data_conversion.py | 87 +++++---- flamingo_tools/file_utils.py | 87 +++++++++ flamingo_tools/mobie.py | 137 +++++++++++--- flamingo_tools/s3_utils.py | 80 +++++--- flamingo_tools/segmentation/postprocessing.py | 37 ++-- .../segmentation/unet_prediction.py | 173 ++++++++++-------- flamingo_tools/test_data.py | 10 +- scripts/prediction/add_to_mobie.py | 46 ++++- .../run_prediction_distance_unet.py | 27 +-- scripts/resize_wrongly_scaled_cochleas.py | 20 +- scripts/training/train_distance_unet.py | 2 +- 11 files changed, 507 insertions(+), 199 deletions(-) create mode 100644 flamingo_tools/file_utils.py diff --git a/flamingo_tools/data_conversion.py b/flamingo_tools/data_conversion.py index 20e3783..9bc4723 100644 --- a/flamingo_tools/data_conversion.py +++ b/flamingo_tools/data_conversion.py @@ -5,16 +5,17 @@ from glob import glob from pathlib import Path -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Tuple import numpy as np import pybdv -import tifffile from cluster_tools.utils.volume_utils import write_format_metadata from elf.io import open_file from skimage.transform import rescale +from .file_utils import read_tif, read_raw + def _read_resolution_and_unit_flamingo(mdata_path): resolution = None @@ -60,7 +61,27 @@ def _read_start_position_flamingo(path): return start_position -def read_metadata_flamingo(metadata_path, offset=None, parse_affine=False): +def read_metadata_flamingo( + metadata_path: str, + offset: Optional[np.ndarray] = None, + parse_affine: bool = False +) -> Tuple[List[float], str, List[float]]: + """Read acquisition metadata from a flamingo metadata file. + + This will read the resolution, the physical unit, and optionally the + voxel grid transformation from the metadata file. The voxel grid transformation + places tile at their correct tile position. + + Args: + metadata_path: The path to the metadata file. + offset: The spatial offset of this data. + parse_affine: Whether to read the affine transformation from the metadata. + + Returns: + The resolution / voxel size of the data. + The physical unit of the voxel size. + The affine voxel grid transformation of the data. + """ resolution, unit = None, None resolution, unit = _read_resolution_and_unit_flamingo(metadata_path) @@ -109,7 +130,7 @@ def _pos_to_trafo(pos): # TODO derive the scale factors from the shape rather than hard-coding it to 5 levels -def derive_scale_factors(shape): +def _derive_scale_factors(shape): scale_factors = [[2, 2, 2]] * 5 return scale_factors @@ -147,11 +168,25 @@ def _to_ome_zarr(data, out_path, scale_factors, timepoint, setup_id, attributes, ) -def flamingo_filename_parser(file_path, name_mapping): +def flamingo_filename_parser(file_path: str, name_mapping: Optional[Dict]) -> Tuple[int, Dict[str, str], str]: + """Parse the name of flamingo output files. + + This maps the filenames to the corresponding timepoint, the BigStitcher + compatible attributes, and the id (name) of the attributes. + + Args: + file_path: The path to the flamingo data. + name_mapping: Optional mapping of parsed attributes to their actual names. + + Returns: + The timepoint of this data. + The dictionary mapping attribute names to their values. + The normalized attribute names. + """ filename = os.path.basename(file_path) # Extract the timepoint. 
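+    # For example, a test-data style file name like "volume_R1_C0_I0.tif"
+    # (see create_test_data) has no "_t..._" tag, so the timepoint defaults to 0,
+    # while the tile and channel are read from the "_R1_" and "_C0_" tags below.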
- match = re.search(r'_t(\d+)_', filename) + match = re.search(r"_t(\d+)_", filename) if match: timepoint = int(match.group(1)) else: @@ -163,25 +198,25 @@ def flamingo_filename_parser(file_path, name_mapping): name_mapping = {} # Extract the channel. - match = re.search(r'_C(\d+)_', filename) + match = re.search(r"_C(\d+)_", filename) channel = int(match.group(1)) if match else 0 channel_mapping = name_mapping.get("channel", {}) attributes["channel"] = {"id": channel, "name": channel_mapping.get(channel, str(channel))} # Extract the tile. - match = re.search(r'_R(\d+)_', filename) + match = re.search(r"_R(\d+)_", filename) tile = int(match.group(1)) if match else 0 tile_mapping = name_mapping.get("tile", {}) attributes["tile"] = {"id": tile, "name": tile_mapping.get(tile, str(tile))} # Extract the illumination. - match = re.search(r'_I(\d+)_', filename) + match = re.search(r"_I(\d+)_", filename) illumination = int(match.group(1)) if match else 0 illumination_mapping = name_mapping.get("illumination", {}) attributes["illumination"] = {"id": illumination, "name": illumination_mapping.get(illumination, str(illumination))} # Extract D. TODO what is this? - match = re.search(r'_D(\d+)_', filename) + match = re.search(r"_D(\d+)_", filename) D = int(match.group(1)) if match else 0 D_mapping = name_mapping.get("D", {}) attributes["D"] = {"id": D, "name": D_mapping.get(D, str(D))} @@ -207,35 +242,11 @@ def _write_missing_views(out_path): tree.write(xml_path) -def _parse_shape(metadata_file): - depth, height, width = None, None, None - - with open(metadata_file, "r") as f: - for line in f.readlines(): - line = line.strip().rstrip("\n") - if line.startswith("AOI width"): - width = int(line.split(" ")[-1]) - if line.startswith("AOI height"): - height = int(line.split(" ")[-1]) - if line.startswith("Number of planes saved"): - depth = int(line.split(" ")[-1]) - - assert depth is not None - assert height is not None - assert width is not None - return (depth, height, width) - - def _load_data(file_path, metadata_file): if Path(file_path).suffix == ".raw": - shape = _parse_shape(metadata_file) - data = np.memmap(file_path, mode="r", dtype="uint16", shape=shape) + data = read_raw(file_path, metadata_file) else: - try: - data = tifffile.memmap(file_path, mode="r") - except ValueError: - print(f"Could not memmap the data from {file_path}. 
Fall back to load it into memory.") - data = tifffile.imread(file_path) + data = read_tif(file_path) return data @@ -360,7 +371,7 @@ def convert_lightsheet_to_bdv( print(f"Converting tp={timepoint}, channel={attributes['channel']}, tile={attributes['tile']}") data = _load_data(file_path, metadata_file) if scale_factors is None: - scale_factors = derive_scale_factors(data.shape) + scale_factors = _derive_scale_factors(data.shape) if convert_to_ome_zarr: _to_ome_zarr(data, out_path, scale_factors, timepoint, setup_id, attributes, unit, resolution) @@ -387,6 +398,8 @@ def convert_lightsheet_to_bdv( def convert_lightsheet_to_bdv_cli(): + """@private + """ import argparse parser = argparse.ArgumentParser( diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py new file mode 100644 index 0000000..fc4e15d --- /dev/null +++ b/flamingo_tools/file_utils.py @@ -0,0 +1,87 @@ +import warnings +from typing import Any, Optional, Union + +import imageio.v3 as imageio +import numpy as np +import tifffile +import zarr +from elf.io import open_file + + +def _parse_shape(metadata_file): + depth, height, width = None, None, None + + with open(metadata_file, "r") as f: + for line in f.readlines(): + line = line.strip().rstrip("\n") + if line.startswith("AOI width"): + width = int(line.split(" ")[-1]) + if line.startswith("AOI height"): + height = int(line.split(" ")[-1]) + if line.startswith("Number of planes saved"): + depth = int(line.split(" ")[-1]) + + assert depth is not None + assert height is not None + assert width is not None + return (depth, height, width) + + +def read_raw(file_path: str, metadata_file: str) -> np.memmap: + """Read a raw file written by the flamingo microscope. + + Args: + file_path: The file path to the raw file. + metadata_file: The file path to the metadata describing the raw file. + The metadata will be used to determine the shape of the data. + + Returns: + The memory-mapped data. + """ + shape = _parse_shape(metadata_file) + return np.memmap(file_path, mode="r", dtype="uint16", shape=shape) + + +def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: + """Read a tif file. + + Tries to memory map the file. If not possible will load the complete file into memory + and raise a warning. + + Args: + file_path: The file path to the tif file. + + Returns: + The memory-mapped data. If not possible to memmap, the data in memory. + """ + try: + x = tifffile.memmap(file_path, "r") + except ValueError: + warnings.warn(f"Cannot memmap the tif file at {file_path}. Fall back to loading it into memory.") + x = imageio.imread(file_path) + return x + + +# TODO: Update the any types: +# The first should be the type of a zarr s3 store, +def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.array_like: + """Read flamingo image data, stored in various formats. + + Args: + input_path: The file path to the data, or a zarr S3 store for data remotely accessed on S3. + The data can be stored as a tif file, or a zarr/n5 container. + Access via S3 is only supported for a zarr container. + input_key: The key (= internal path) for a zarr or n5 container. + Set it to None if the data is stored in a tif file. + + Returns: + The data, loaded either as a numpy mem-map, a numpy array, or a zarr / n5 array. 
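+
+    Example (hypothetical paths; a minimal sketch of both input types):
+        data = read_image_data("volume.tif", None)
+        data = read_image_data("volume.n5", "setup0/timepoint0/s0")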
+ """ + if input_key is None: + input_ = read_tif(input_path) + elif isinstance(input_path, str): + input_ = open_file(input_path, "r")[input_key] + else: + with zarr.open(input_path, mode="r") as f: + input_ = f[input_key] + return input_ diff --git a/flamingo_tools/mobie.py b/flamingo_tools/mobie.py index ae29e32..9b32a3f 100644 --- a/flamingo_tools/mobie.py +++ b/flamingo_tools/mobie.py @@ -1,12 +1,18 @@ import os +import multiprocessing as mp import tempfile -from typing import Tuple +from typing import Optional, Tuple -from mobie import add_bdv_image, add_segmentation +from elf.io import open_file +from mobie import add_bdv_image, add_image, add_segmentation from mobie.metadata.dataset_metadata import read_dataset_metadata +DEFAULT_RESOLUTION = (0.38, 0.38, 0.38) +DEFAULT_SCALE_FACTORS = [[2, 2, 2]] * 5 +DEFAULT_CHUNKS = (128, 128, 128) +DEFAULT_UNIT = "micrometer" + -# TODO refactor to mobie utils def _source_exists(mobie_project, mobie_dataset, source_name): dataset_folder = os.path.join(mobie_project, mobie_dataset) metadata = read_dataset_metadata(dataset_folder) @@ -14,15 +20,53 @@ def _source_exists(mobie_project, mobie_dataset, source_name): return source_name in sources +def _parse_spatial_args( + resolution, scale_factors, chunks, input_path, input_key +): + if resolution is None: + resolution = DEFAULT_RESOLUTION + if scale_factors is None: + scale_factors = DEFAULT_SCALE_FACTORS + if chunks is None: + if input_path.endswith(".tif"): + chunks = DEFAULT_CHUNKS + else: + with open_file(input_path, "r") as f: + chunks = f[input_key].chunks + return resolution, scale_factors, chunks + + def add_raw_to_mobie( mobie_project: str, mobie_dataset: str, source_name: str, - xml_path: str, + input_path: str, skip_existing: bool = True, + input_key: Optional[str] = None, setup_id: int = 0, -): - """ + resolution: Optional[Tuple[float, float, float]] = None, + scale_factors: Optional[Tuple[Tuple[int, int, int]]] = None, + chunks: Optional[Tuple[int, int, int]] = None, +) -> None: + """Add image data to a MoBIE project. + + The input may either be an xml file in BigDataViewer / BigStitcher format, + a n5 / hdf5 / zarr file, or a tif file. + + Args: + mobie_project: The MoBIE project directory. + mobie_dataset The MoBIE dataset the image data will be added to. + source_name: The name of the data to use in MoBIE. + input_path: The path to the data. + skip_existing: Whether to skip existing dataset. + If this is set to false, then an exception will be thrown if the source already + exists in the MoBIE dataset. + input_key: The key of the input data. This only has to be specified if the input is + a n5 / hdf5 / zarr file. + setup_id: The setup_id that will be added to MoBIE. This is only used if the input data is an xml file. + resolution: The resolution / voxel size of the data. + scale_factors: The factors to use for downsampling the data when creating the multi-level image pyramid. + chunks: The output chunks for writing the data. """ # Check if we have converted this data already. 
have_source = _source_exists(mobie_project, mobie_dataset, source_name) @@ -33,16 +77,42 @@ def add_raw_to_mobie( elif have_source: raise NotImplementedError + max_jobs = min(16, mp.cpu_count()) with tempfile.TemporaryDirectory() as tmpdir: - add_bdv_image( - xml_path=xml_path, - root=mobie_project, - dataset_name=mobie_dataset, - image_name=source_name, - tmp_folder=tmpdir, - file_format="bdv.n5", - setup_ids=[setup_id], - ) + if input_path.endswith(".xml"): + add_bdv_image( + xml_path=input_path, + root=mobie_project, + dataset_name=mobie_dataset, + image_name=source_name, + tmp_folder=tmpdir, + file_format="bdv.n5", + setup_ids=[setup_id], + ) + else: + use_memmap = False + if input_path.endswith(".tif"): + use_memmap = True + assert input_key is None + else: + input_key = "setup0/timepoint0/s0" if input_key is None else input_key + resolution, scale_factors, chunks = _parse_spatial_args( + resolution, scale_factors, chunks, input_path, input_key + ) + add_image( + input_path=input_path, + input_key=input_key, + root=mobie_project, + dataset_name=mobie_dataset, + image_name=source_name, + resolution=resolution, + scale_factors=scale_factors, + chunks=chunks, + tmp_folder=tmpdir, + use_memmap=use_memmap, + unit=DEFAULT_UNIT, + max_jobs=max_jobs, + ) def add_segmentation_to_mobie( @@ -51,12 +121,26 @@ def add_segmentation_to_mobie( source_name: str, segmentation_path: str, segmentation_key: str, - resolution: Tuple[float, float, float], - unit: str, - scale_factors: Tuple[Tuple[int, int, int]], - chunks: Tuple[int, int, int], + resolution: Optional[Tuple[float, float, float]] = None, + scale_factors: Optional[Tuple[Tuple[int, int, int]]] = None, + chunks: Optional[Tuple[int, int, int]] = None, skip_existing: bool = True, -): +) -> None: + """Add a segmentation to a MoBIE dataset. + + Args: + mobie_project: The MoBIE project directory. + mobie_dataset The MoBIE dataset the segmentation will be added to. + source_name: The name of the data to use in MoBIE. + segmentation_path: The path to the data. + segmentation_key: The key of the data. + resolution: The resolution / voxel size of the data. + scale_factors: The factors to use for downsampling the data when creating the multi-level image pyramid. + chunks: The output chunks for writing the data. + skip_existing: Whether to skip existing dataset. + If this is set to false, then an exception will be thrown if the source already + exists in the MoBIE dataset. + """ # Check if we have converted this data already. 
have_source = _source_exists(mobie_project, mobie_dataset, source_name) if have_source and skip_existing: @@ -66,12 +150,19 @@ def add_segmentation_to_mobie( elif have_source: raise NotImplementedError + resolution, scale_factors, chunks = _parse_spatial_args( + resolution, scale_factors, chunks, segmentation_path, segmentation_key + ) + + max_jobs = min(16, mp.cpu_count()) with tempfile.TemporaryDirectory() as tmpdir: add_segmentation( input_path=segmentation_path, input_key=segmentation_key, root=mobie_project, dataset_name=mobie_dataset, segmentation_name=source_name, - resolution=resolution, scale_factors=scale_factors, - chunks=chunks, file_format="bdv.n5", - tmp_folder=tmpdir + resolution=resolution, + scale_factors=scale_factors, + chunks=chunks, + tmp_folder=tmpdir, + max_jobs=max_jobs, ) diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index c371af4..92c4e4f 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -1,12 +1,12 @@ +"""This file contains utility functions for processing data located on an S3 storage. +The upload of data to the storage system should be performed with 'rclone'. +""" import os +from typing import Optional, Tuple import s3fs import zarr -""" -This script contains utility functions for processing data located on an S3 storage. -The upload of data to the storage system should be performed with 'rclone'. -""" # Dedicated bucket for cochlea lightsheet project MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" @@ -15,45 +15,65 @@ DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials") -# For MoBIE: -# https://s3.gwdg.de/incucyte-general/lightsheet -def check_s3_credentials(bucket_name, service_endpoint, credential_file): - """ - Check if S3 parameter and credentials were set either as a function input or were exported as environment variables. +def check_s3_credentials( + bucket_name: Optional[str], service_endpoint: Optional[str], credential_file: Optional[str] +) -> Tuple[str, str, str]: + """Check if S3 parameter and credentials were set either as input variables or as environment variables. 
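+
+    If the arguments are not given, the values are looked up in the environment variables
+    BUCKET_NAME, SERVICE_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY,
+    falling back to the module-level defaults and the default credential file.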
+ + Args: + + Returns: """ if bucket_name is None: - bucket_name = os.getenv('BUCKET_NAME') + bucket_name = os.getenv("BUCKET_NAME") if bucket_name is None: if BUCKET_NAME in globals(): bucket_name = BUCKET_NAME else: - raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") + raise ValueError( + "Provide a bucket name for accessing S3 data.\n" + "Either by using an optional argument or exporting an environment variable:\n" + "--s3_bucket_name \n" + "export BUCKET_NAME=" + ) if service_endpoint is None: - service_endpoint = os.getenv('SERVICE_ENDPOINT') + service_endpoint = os.getenv("SERVICE_ENDPOINT") if service_endpoint is None: if SERVICE_ENDPOINT in globals(): service_endpoint = SERVICE_ENDPOINT else: - raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") + raise ValueError( + "Provide a service endpoint for accessing S3 data.\n" + "Either by using an optional argument or exporting an environment variable:\n" + "--s3_service_endpoint \n" + "export SERVICE_ENDPOINT=") if credential_file is None: - access_key = os.getenv('AWS_ACCESS_KEY_ID') - secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') + access_key = os.getenv("AWS_ACCESS_KEY_ID") + secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") # check for default credentials if no credential_file is provided if access_key is None: if os.path.isfile(DEFAULT_CREDENTIALS): access_key, _ = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS) else: - raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") + raise ValueError( + "Either provide a credential file as an optional argument," + f" have credentials at '{DEFAULT_CREDENTIALS}'," + " or export an access key as an environment variable:\n" + "export AWS_ACCESS_KEY_ID=") if secret_key is None: # check for default credentials if os.path.isfile(DEFAULT_CREDENTIALS): _, secret_key = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS) else: - raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") + raise ValueError( + "Either provide a credential file as an optional argument," + f" have credentials at '{DEFAULT_CREDENTIALS}'," + " or export a secret access key as an environment variable:\n" + "export AWS_SECRET_ACCESS_KEY=") else: # check validity of credential file @@ -61,19 +81,23 @@ def check_s3_credentials(bucket_name, service_endpoint, credential_file): return bucket_name, service_endpoint, credential_file + def get_s3_path( - input_path, - bucket_name=None, service_endpoint=None, - credential_file=None, + input_path: str, + bucket_name: Optional[str] = None, + service_endpoint: Optional[str] = None, + credential_file: Optional[str] = None, + # ) -> Tuple[]: ): + """Get S3 path for a file or folder and file system based on S3 parameters and credentials. """ - Get S3 path for a file or folder and file system based on S3 parameters and credentials. 
- """ - bucket_name, service_endpoint, credential_file = check_s3_credentials(bucket_name, service_endpoint, credential_file) + bucket_name, service_endpoint, credential_file = check_s3_credentials( + bucket_name, service_endpoint, credential_file + ) fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) - zarr_path=f"{bucket_name}/{input_path}" + zarr_path = f"{bucket_name}/{input_path}" if not fs.exists(zarr_path): print(f"Error: S3 path {zarr_path} does not exist!") @@ -96,13 +120,13 @@ def read_s3_credentials(credential_file): return key, secret -def create_s3_target(url, anon=False, credential_file=None): - """ - Create file system for S3 bucket based on a service endpoint and an optional credential file. +def create_s3_target(url=None, anon=False, credential_file=None): + """Create file system for S3 bucket based on a service endpoint and an optional credential file. + If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. """ - client_kwargs = {"endpoint_url": url} + client_kwargs = {"endpoint_url": SERVICE_ENDPOINT if url is None else url} if credential_file is not None: key, secret = read_s3_credentials(credential_file) fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 1264956..06f896d 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -13,9 +13,9 @@ from elf.io import open_file import nifty.tools as nt + def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True): - """ - Calculate average distance of n nearest neighbors. + """Calculate average distance of n nearest neighbors. :param DataFrame tsv_table: :param int n_neighbors: Number of nearest neighbors @@ -39,14 +39,16 @@ def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True): return distance_avg + def filter_isolated_objects( - segmentation, output_path, tsv_table=None, - distance_threshold=15, neighbor_threshold=5, min_size=1000, - output_key="segmentation_postprocessed", - ): - """ - Postprocessing step to filter isolated objects from a segmentation. - Instance segmentations are filtered if they have fewer neighbors than a given threshold in a given distance around them. + segmentation, output_path, tsv_table=None, + distance_threshold=15, neighbor_threshold=5, min_size=1000, + output_key="segmentation_postprocessed", +): + """Postprocessing step to filter isolated objects from a segmentation. + + Instance segmentations are filtered if they have fewer neighbors + than a given threshold in a given distance around them. Additionally, size filtering is possible if a TSV file is supplied. 
:param dataset segmentation: Dataset containing the segmentation @@ -65,9 +67,9 @@ def filter_isolated_objects( # filter out cells smaller than min_size if min_size is not None: - min_size_label_ids = [l for (l,n) in zip(label_ids, n_pixels) if n <= min_size] - centroids = [c for (c,l) in zip(centroids, label_ids) if l not in min_size_label_ids] - label_ids = [int(l) for l in label_ids if l not in min_size_label_ids] + min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size] + centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids] + label_ids = [int(lid) for lid in label_ids if lid not in min_size_label_ids] coordinates = np.array(centroids) label_ids = np.array(label_ids) @@ -92,8 +94,8 @@ def filter_isolated_objects( filter_ids = label_ids[filter_mask] shape = segmentation.shape - block_shape=(128,128,128) - chunks=(128,128,128) + block_shape = (128, 128, 128) + chunks = (128, 128, 128) blocking = nt.blocking([0] * len(shape), shape, block_shape) @@ -105,8 +107,7 @@ def filter_isolated_objects( ) def filter_chunk(block_id): - """ - Set all points within a chunk to zero if they match filter IDs. + """Set all points within a chunk to zero if they match filter IDs. """ block = blocking.getBlock(block_id) volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) @@ -120,6 +121,8 @@ def filter_chunk(block_id): with futures.ThreadPoolExecutor(n_threads) as filter_pool: list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks)) - seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=(128,128,128)) + seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive( + output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128) + ) return seg_filtered, n_ids, n_ids_filtered diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 35298bb..3cfb1b0 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -1,19 +1,21 @@ +"""Prediction using distance U-Net. +Parallelization using multiple GPUs is currently only possible by calling functions directly. +Functions for the parallelization end with '_slurm' +and divide the process into preprocessing, prediction, and segmentation. +""" +import json import multiprocessing as mp import os -import sys import warnings from concurrent import futures +from typing import Optional, Tuple -import imageio.v3 as imageio import elf.parallel as parallel import numpy as np import nifty.tools as nt import vigra import torch import z5py -import zarr -import tifffile -import json from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper from elf.wrapper.resized_volume import ResizedVolume @@ -23,15 +25,17 @@ from tqdm import tqdm import flamingo_tools.s3_utils as s3_utils +from flamingl_tools.file_utils import read_image_data -""" -Prediction using distance U-Net. -Parallelization using multiple GPUs is currently only possible by calling functions directly. -Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation. -""" class SelectChannel(SimpleTransformationWrapper): - def __init__(self, volume, channel): + """Wrapper to select a chanel from an array-like dataset object. + + Args: + volume: The array-like input dataset. + channel: The channel that will be selected. 
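+
+    For example, SelectChannel(prediction, 0) lets a prediction volume with shape
+    (channel, z, y, x) be processed like the 3D volume prediction[0].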
+ """ + def __init__(self, volume: np.array_like, channel: int): self.channel = channel super().__init__(volume, lambda x: x[self.channel], with_channels=True) @@ -48,7 +52,21 @@ def ndim(self): return self._volume.ndim - 1 -def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None, s3=None): +def prediction_impl( + input_path, + input_key, + output_folder, + model_path, + scale, + block_shape, + halo, + prediction_instances=1, + slurm_task_id=0, + mean=None, + std=None, +): + """@private + """ with warnings.catch_warnings(): warnings.simplefilter("ignore") if os.path.isdir(model_path): @@ -59,21 +77,10 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo mask_path = os.path.join(output_folder, "mask.zarr") image_mask = z5py.File(mask_path, "r")["mask"] - if input_key is None: - try: - input_ = tifffile.memmap(input_path, mode="r") - except ValueError: - print(f"Could not memmap the data from {input_path}. Fall back to load it into memory.") - input_ = imageio.imread(input_path) - elif isinstance(input_path, str): - input_ = open_file(input_path, "r")[input_key] - else: - with zarr.open(input_path, mode="r") as f: - input_ = f[input_key] - - chunks = getattr(input_, "chunks", (64,64,64)) + input_ = read_image_data(input_path, input_key) + chunks = getattr(input_, "chunks", (64, 64, 64)) - if scale is None or scale == 1: + if scale is None or np.isclose(scale, 1): original_shape = None else: original_shape = input_.shape @@ -102,7 +109,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo # Compute the global mean and standard deviation. n_threads = min(16, mp.cpu_count()) mean, std = parallel.mean_and_std( - input_, block_shape=tuple([2* i for i in chunks]), n_threads=n_threads, verbose=True, + input_, block_shape=tuple([2 * i for i in chunks]), n_threads=n_threads, verbose=True, mask=image_mask ) print("Mean and standard deviation computed for the full volume:") @@ -152,7 +159,19 @@ def postprocess(x): return original_shape -def find_mask(input_path, input_key, output_folder, s3=None): +def find_mask(input_path: str, input_key: Optional[str], output_folder: str) -> None: + """Determine the mask for running prediction. + + The mask corresponds to data that contains actual signal and not just noise. + This is determined by checking if the 95th percentile of the intensity + of a local block has a value larger than 200. It may be necesary to choose a + different criterion if the data acquisition changes. + + Args: + input_path: The file path to the image data. + input_key: The key / internal path of the image data. + output_folder: The output folder for storing the mask data. + """ mask_path = os.path.join(output_folder, "mask.zarr") f = z5py.File(mask_path, "a") @@ -160,20 +179,8 @@ def find_mask(input_path, input_key, output_folder, s3=None): if mask_key in f: return - if input_key is None: - try: - raw = tifffile.memmap(input_path, mode="r") - except ValueError: - print(f"Could not memmap the data from {input_path}. 
Fall back to load it into memory.") - raw = imageio.imread(input_path) - elif isinstance(input_path, str): - fin = open_file(input_path, "r") - raw = fin[input_key] - else: - with zarr.open(input_path, mode="r") as fin: - raw = fin[input_key] - - chunks = getattr(raw, "chunks", (64,64,64)) + raw = read_image_data(input_path, input_key) + chunks = getattr(raw, "chunks", (64, 64, 64)) block_shape = tuple(2 * ch for ch in chunks) blocking = nt.blocking([0, 0, 0], raw.shape, block_shape) @@ -196,6 +203,8 @@ def find_mask_block(block_id): def segmentation_impl(input_path, output_folder, min_size, original_shape=None): + """@private + """ input_ = open_file(input_path, "r")["prediction"] # Limit the number of cores for parallelization. @@ -264,43 +273,56 @@ def write_block(block_id): tp.map(write_block, range(blocking.numberOfBlocks)) -def calc_mean_and_std( - input_path, input_key, output_folder, - s3=None, - ): - """ - Calculate mean and standard deviation of full volume. - Parameters are saved in 'mean_std.json' within the output folder. +def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> None: + """Calculate mean and standard deviation of the input volume. + + The parameters are saved in 'mean_std.json' in the output folder. + + Args: + input_path: The file path to the image data. + input_key: The key / internal path of the image data. + output_folder: The output folder for storing the segmentation related data. """ json_file = os.path.join(output_folder, "mean_std.json") mask_path = os.path.join(output_folder, "mask.zarr") image_mask = z5py.File(mask_path, "r")["mask"] - if input_key is None: - input_ = imageio.imread(input_path) - elif s3 is not None: - with zarr.open(input_path, mode="r") as f: - input_ = f[input_key] - else: - input_ = open_file(input_path, "r")[input_key] + input_ = read_image_data(input_path, input_key) + chunks = getattr(input_, "chunks", (64, 64, 64)) # Compute the global mean and standard deviation. n_threads = min(16, mp.cpu_count()) mean, std = parallel.mean_and_std( - input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True, - mask=image_mask + input_, block_shape=tuple([2 * i for i in chunks]), n_threads=n_threads, verbose=True, mask=image_mask ) - ddict = {"mean":mean, "std":std} + ddict = {"mean": mean, "std": std} with open(json_file, "w") as f: json.dump(ddict, f) def run_unet_prediction( - input_path, input_key, - output_folder, model_path, - min_size, scale=None, - block_shape=None, halo=None, -): + input_path: str, + input_key: Optional[str], + output_folder: str, + model_path: str, + min_size: int, + scale: Optional[float] = None, + block_shape: Optional[Tuple[int, int, int]] = None, + halo: Optional[Tuple[int, int, int]] = None, +) -> None: + """Run prediction and segmentation with a distance U-Net. + + Args: + input_path: The path to the input data. + input_key: The key / internal path of the image data. + output_folder: The output folder for storing the segmentation related data. + model_path: The path to the model to use for segmentation. + min_size: The minimal size of segmented objects in the output. + scale: A factor to rescale the data before prediction. + By default the data will not be rescaled. + block_shape: The block-shape for running the prediction. + halo: The halo (= block overlap) to use for prediction. 
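+
+    Example (hypothetical paths; a minimal sketch):
+        run_unet_prediction(
+            "cochlea.n5", "setup0/timepoint0/s0", output_folder="./seg",
+            model_path="./distance_unet_model", min_size=1000,
+        )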
+ """ os.makedirs(output_folder, exist_ok=True) find_mask(input_path, input_key, output_folder) @@ -312,20 +334,25 @@ def run_unet_prediction( pmap_out = os.path.join(output_folder, "predictions.zarr") segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape) -#---Workflow for parallel prediction using slurm--- + +# +# ---Workflow for parallel prediction using slurm--- +# + def run_unet_prediction_preprocess_slurm( - input_path, input_key, output_folder, - s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, + input_path, input_key, output_folder, + s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, ): - """ - Pre-processing for the parallel prediction with U-Net models. + """Pre-processing for the parallel prediction with U-Net models. Masks are stored in mask.zarr in the output folder. The mean and standard deviation are precomputed for later usage during prediction and stored in a JSON file within the output folder as mean_std.json. """ if s3 is not None: - input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + input_path, fs = s3_utils.get_s3_path( + input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials + ) if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): find_mask(input_path, input_key, output_folder, s3=s3) @@ -360,7 +387,9 @@ def run_unet_prediction_slurm( slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID") if s3 is not None: - input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + input_path, fs = s3_utils.get_s3_path( + input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials + ) if slurm_task_id is not None: slurm_task_id = int(slurm_task_id) @@ -380,7 +409,7 @@ def run_unet_prediction_slurm( mean = None std = None - original_shape = prediction_impl( + prediction_impl( input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=prediction_instances, slurm_task_id=slurm_task_id, mean=mean, std=std, s3=s3, diff --git a/flamingo_tools/test_data.py b/flamingo_tools/test_data.py index 5e52cab..bca605d 100644 --- a/flamingo_tools/test_data.py +++ b/flamingo_tools/test_data.py @@ -5,7 +5,15 @@ # TODO add metadata -def create_test_data(root, size=256, n_channels=2, n_tiles=4): +def create_test_data(root: str, size: int = 256, n_channels: int = 2, n_tiles: int = 4) -> None: + """Create test data in the flamingo data format. + + Args: + root: Directory for saving the data. + size: The axis length for the data. + n_channels The number of channels to create: + n_tiles: The number of tiles to create. 
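+
+    The volumes are written as <root>/channel<c>/volume_R<tile>_C<channel>_I0.tif,
+    following the flamingo naming convention (R = tile, C = channel, I = illumination).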
+ """ channel_folders = [f"channel{chan_id}" for chan_id in range(n_channels)] file_name_pattern = "volume_R%i_C%i_I0.tif" for chan_id, channel_folder in enumerate(channel_folders): diff --git a/scripts/prediction/add_to_mobie.py b/scripts/prediction/add_to_mobie.py index 4640904..3ebb2a3 100644 --- a/scripts/prediction/add_to_mobie.py +++ b/scripts/prediction/add_to_mobie.py @@ -1 +1,45 @@ -# TODO +import argparse + +from flamingo_tools.mobie import add_raw_to_mobie, add_segmentation_to_mobie +from flamingo_tools.s3_utils import MOBIE_FOLDER + + +# TODO could also refactor this into flamingo_utils.mobie +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_path", required=True) + parser.add_argument("-d", "--dataset", required=True) + parser.add_argument("-t", "--type", required=True) + parser.add_argument("-n", "--name", required=True) + parser.add_argument("-k", "--input_key") + + # TODO add optional arguments: + # - over-ride the mobie folder + # - ??? + + args = parser.parse_args() + project_folder = MOBIE_FOLDER + + if args.type == "image": + add_raw_to_mobie( + mobie_project=project_folder, + mobie_dataset=args.dataset, + source_name=args.name, + input_path=args.input_path, + input_key=args.input_key, + ) + elif args.type == "segmentation": + segmentation_key = "segmentation" if args.input_key is None else args.input_key + add_segmentation_to_mobie( + mobie_project=project_folder, + mobie_dataset=args.dataset, + source_name=args.name, + segmentation_path=args.input_path, + segmentation_key=segmentation_key, + ) + else: + raise ValueError(f"Invalid type for a mobie project: {args.name}.") + + +if __name__ == "__main__": + main() diff --git a/scripts/prediction/run_prediction_distance_unet.py b/scripts/prediction/run_prediction_distance_unet.py index fe30ae2..edca013 100644 --- a/scripts/prediction/run_prediction_distance_unet.py +++ b/scripts/prediction/run_prediction_distance_unet.py @@ -1,16 +1,14 @@ +"""Prediction using distance U-Net. +Parallelization using multiple GPUs is currently only possible +by calling functions located in segmentation/unet_prediction.py directly. +Functions for the parallelization end with '_slurm' and divide the process into preprocessing, +prediction, and segmentation. +""" import argparse -import sys import torch import z5py -sys.path.append("../..") - -""" -Prediction using distance U-Net. -Parallelization using multiple GPUs is currently only possible by calling functions located in segmentation/unet_prediction.py directly. -Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation. 
-""" def main(): from flamingo_tools.segmentation import run_unet_prediction @@ -21,6 +19,7 @@ def main(): parser.add_argument("-m", "--model", required=True) parser.add_argument("-k", "--input_key", default=None) parser.add_argument("-s", "--scale", default=None, type=float, help="Downscale the image by the given factor.") + parser.add_argument("-b", "--block_shape", default=None, type=int, nargs=3) args = parser.parse_args() @@ -34,11 +33,17 @@ def main(): have_cuda = torch.cuda.is_available() if args.input_key is None: - block_shape = (64, 256, 256) if have_cuda else (64, 64, 64) + if args.block_shape is None: + block_shape = (64, 256, 256) if have_cuda else (64, 64, 64) + else: + block_shape = tuple(args.block_shape) halo = (16, 64, 64) if have_cuda else (8, 32, 32) else: - chunks = z5py.File(args.input, "r")[args.input_key].chunks - block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks) + if args.block_shape is None: + chunks = z5py.File(args.input, "r")[args.input_key].chunks + block_shape = tuple([2 * ch for ch in chunks]) if have_cuda else tuple(chunks) + else: + block_shape = tuple(args.block_shape) halo = (16, 64, 64) if have_cuda else (8, 32, 32) run_unet_prediction( diff --git a/scripts/resize_wrongly_scaled_cochleas.py b/scripts/resize_wrongly_scaled_cochleas.py index dc76ca9..5de7064 100644 --- a/scripts/resize_wrongly_scaled_cochleas.py +++ b/scripts/resize_wrongly_scaled_cochleas.py @@ -4,17 +4,17 @@ import multiprocessing as mp from concurrent import futures -import imageio.v3 as imageio import nifty.tools as nt from tqdm import tqdm from elf.wrapper.resized_volume import ResizedVolume from elf.io import open_file +from flamingo_tools.file_utils import read_tif def main(input_path, output_folder, scale, input_key, interpolation_order): if input_path.endswith(".tif"): - input_ = imageio.imread(input_path) + input_ = read_tif(input_path) input_chunks = (128,) * 3 else: input_ = open_file(input_path, "r")[input_key] @@ -51,7 +51,11 @@ def copy_chunk(block_index): output_dataset[volume_index] = data with futures.ThreadPoolExecutor(n_threads) as resize_pool: - list(tqdm(resize_pool.map(copy_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks)) + list(tqdm( + resize_pool.map(copy_chunk, range(blocking.numberOfBlocks)), + total=blocking.numberOfBlocks, + desc=f"Resizing volume from shape {shape} to {new_shape}" + )) if __name__ == "__main__": @@ -59,14 +63,14 @@ def copy_chunk(block_index): parser = argparse.ArgumentParser( description="Script for resizing microscoopy data in n5 format.") - parser.add_argument('input_file', type=str, help="Input file") + parser.add_argument("input_file", type=str, help="Input file") parser.add_argument( - 'output_folder', type=str, help="Output folder. Default resized output is _resized.n5" + "output_folder", type=str, help="Output folder. Default resized output is _resized.n5" ) - parser.add_argument('-s', "--scale", type=float, default=0.38, help="Scale of input. Re-scaled to 1.") - parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for n5 file.") - parser.add_argument('-i', "--interpolation_order", type=float, default=3, help="Interpolation order.") + parser.add_argument("-s", "--scale", type=float, default=0.38, help="Scale of input. 
Re-scaled to 1.") + parser.add_argument("-k", "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for n5 file.") + parser.add_argument("-i", "--interpolation_order", type=float, default=3, help="Interpolation order.") args = parser.parse_args() diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py index 575ed8e..518123b 100644 --- a/scripts/training/train_distance_unet.py +++ b/scripts/training/train_distance_unet.py @@ -116,7 +116,7 @@ def main(): run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name # Parameters for training on A100. - n_iterations = 1e5 + n_iterations = int(1e5) patch_shape = (64, 128, 128) # The U-Net. From d9ae457949d9a6f98505b19bf8414b0c63373756 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 21 Apr 2025 18:10:10 +0200 Subject: [PATCH 2/7] Update postprocessing code --- flamingo_tools/file_utils.py | 2 +- flamingo_tools/segmentation/__init__.py | 2 +- flamingo_tools/segmentation/postprocessing.py | 223 ++++++++++++------ .../segmentation/unet_prediction.py | 4 +- scripts/prediction/postprocess_seg.py | 7 +- 5 files changed, 156 insertions(+), 82 deletions(-) diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index fc4e15d..2d2ac04 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -64,7 +64,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: # TODO: Update the any types: # The first should be the type of a zarr s3 store, -def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.array_like: +def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.typing.ArrayLike: """Read flamingo image data, stored in various formats. Args: diff --git a/flamingo_tools/segmentation/__init__.py b/flamingo_tools/segmentation/__init__.py index 172787c..5c57721 100644 --- a/flamingo_tools/segmentation/__init__.py +++ b/flamingo_tools/segmentation/__init__.py @@ -1,2 +1,2 @@ from .unet_prediction import run_unet_prediction -from .postprocessing import filter_isolated_objects +from .postprocessing import filter_segmentation diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 06f896d..72d019e 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -1,97 +1,174 @@ -import numpy as np -import vigra import multiprocessing as mp from concurrent import futures +from typing import Callable, Tuple, Optional -from skimage import measure +import elf.parallel as parallel +import numpy as np +import nifty.tools as nt +import pandas as pd +import vigra + +from elf.io import open_file from scipy.spatial import distance from scipy.sparse import csr_matrix -from tqdm import tqdm +from scipy.spatial import cKDTree, ConvexHull +from skimage import measure from sklearn.neighbors import NearestNeighbors +from tqdm import tqdm -import elf.parallel as parallel -from elf.io import open_file -import nifty.tools as nt +# +# Spatial statistics: +# Three different spatial statistics implementations that +# can be used as the basis of a filtering criterion. +# -def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True): - """Calculate average distance of n nearest neighbors. 
- :param DataFrame tsv_table: - :param int n_neighbors: Number of nearest neighbors - :param bool expand_table: Flag for expanding DataFrame - :returns: List of average distances - :rtype: list - """ - centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"])) +def nearest_neighbor_distance(table: pd.DataFrame, n_neighbors: int = 10) -> np.ndarray: + """Compute the average distance to the n nearest neighbors. + + Args: + table: The table with the centroid coordinates. + n_neighbors: The number of neighbors to take into account for the distance computation. - coordinates = np.array(centroids) + Returns: + The average distances to the n nearest neighbors. + """ + centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + centroids = np.array(centroids) - # nearest neighbor is always itself, so n_neighbors+=1 - nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates) - distances, indices = nbrs.kneighbors(coordinates) + # Nearest neighbor is always itself, so n_neighbors+=1. + nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(centroids) + distances, indices = nbrs.kneighbors(centroids) # Average distance to nearest neighbors - distance_avg = [sum(d) / len(d) for d in distances[:, 1:]] + distance_avg = np.array([sum(d) / len(d) for d in distances[:, 1:]]) + return distance_avg - if expand_table: - tsv_table['distance_nn'+str(n_neighbors)] = distance_avg - return distance_avg +def local_ripleys_k(table: pd.DataFrame, radius: float = 15, volume: Optional[float] = None) -> np.ndarray: + """Compute the local Ripley's K function for each point in a 2D / 3D. + Args: + table: The table with the centroid coordinates. + radius: The radius within which to count neighboring points. + volume: The area (2D) or volume (3D) of the study region. If None, it is estimated from the convex hull. + + Returns: + An array containing the local K values for each point. + """ + points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + points = np.array(points) + n_points, dim = points.shape + + if dim not in (2, 3): + raise ValueError("Points array must be of shape (n_points, 2) or (n_points, 3).") + + # Estimate area/volume if not provided. + if volume is None: + hull = ConvexHull(points) + volume = hull.volume # For 2D, 'volume' is area; for 3D, it's volume. + + # Compute point density. + density = n_points / volume + + # Build a KD-tree for efficient neighbor search. + tree = cKDTree(points) + + # Count neighbors within the specified radius for each point + counts = tree.query_ball_point(points, r=radius) + local_counts = np.array([len(c) - 1 for c in counts]) # Exclude the point itself + + # Normalize by density to get local K values + local_k = local_counts / density + return local_k -def filter_isolated_objects( - segmentation, output_path, tsv_table=None, - distance_threshold=15, neighbor_threshold=5, min_size=1000, - output_key="segmentation_postprocessed", -): - """Postprocessing step to filter isolated objects from a segmentation. - Instance segmentations are filtered if they have fewer neighbors - than a given threshold in a given distance around them. - Additionally, size filtering is possible if a TSV file is supplied. 
- - :param dataset segmentation: Dataset containing the segmentation - :param str out_path: Output path for postprocessed segmentation - :param str tsv_file: Optional TSV file containing segmentation parameters in MoBIE format - :param int distance_threshold: Distance in micrometer to check for neighbors - :param int neighbor_threshold: Minimal number of neighbors for filtering - :param int min_size: Minimal number of pixels for filtering small instances - :param str output_key: Output key for postprocessed segmentation +def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray: + """Compute the number of neighbors within a given radius. + + Args: + table: The table with the centroid coordinates. + radius: The radius within which to count neighboring points. + + Returns: + An array containing the number of neighbors within the given radius. """ - if tsv_table is not None: - n_pixels = tsv_table["n_pixels"].to_list() - label_ids = tsv_table["label_id"].to_list() - centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"])) - n_ids = len(label_ids) - - # filter out cells smaller than min_size - if min_size is not None: - min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size] - centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids] - label_ids = [int(lid) for lid in label_ids if lid not in min_size_label_ids] - - coordinates = np.array(centroids) - label_ids = np.array(label_ids) - - else: - segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True) - props = measure.regionprops(segmentation) - coordinates = np.array([prop.centroid for prop in props]) - label_ids = np.unique(segmentation)[1:] - - # Calculate pairwise distances and convert to a square matrix - dist_matrix = distance.pdist(coordinates) + points = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"])) + points = np.array(points) + + dist_matrix = distance.pdist(points) dist_matrix = distance.squareform(dist_matrix) - # Create sparse matrix of connections within the threshold distance - sparse_matrix = csr_matrix(dist_matrix < distance_threshold, dtype=int) + # Create sparse matrix of connections within the threshold distance. + sparse_matrix = csr_matrix(dist_matrix < radius, dtype=int) - # Sum each row to count neighbors + # Sum each row to count neighbors. neighbor_counts = sparse_matrix.sum(axis=1) + return np.array(neighbor_counts) + + +# +# Filtering function: +# Filter the segmentation based on a spatial statistics from above. +# + + +def _compute_table(segmentation): + segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True) + props = measure.regionprops(segmentation) + coordinates = np.array([prop.centroid for prop in props])[1:] + label_ids = np.unique(segmentation)[1:] + sizes = np.array([prop.area for prop in props])[1:] + table = pd.DataFrame({ + "label_id": label_ids, + "n_pixels": sizes, + "anchor_x": coordinates[:, 2], + "anchor_y": coordinates[:, 1], + "anchor_z": coordinates[:, 0], + }) + return table + + +def filter_segmentation( + segmentation: np.typing.ArrayLike, + output_path: str, + spatial_statistics: Callable, + threshold: float, + min_size: int = 1000, + table: Optional[pd.DataFrame] = None, + output_key: str = "segmentation_postprocessed", +) -> Tuple[int, int]: + """Postprocessing step to filter isolated objects from a segmentation. 
- filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze() - filter_ids = label_ids[filter_mask] + Instance segmentations are filtered based on spatial statistics and a threshold. + In addition, objects smaller than a given size are filtered out. + + Args: + segmentation: Dataset containing the segmentation + output_path: Output path for postprocessed segmentation + spatial_statistics: + threshold: Distance in micrometer to check for neighbors + min_size: Minimal number of pixels for filtering small instances + table: + output_key: Output key for postprocessed segmentation + + Returns: + n_ids + n_ids_filtered + """ + # Compute the table on the fly. + # NOTE: this currently doesn't work for large segmentations. + if table is None: + table = _compute_table(segmentation) + n_ids = len(table) + + # First apply the size filter. + table = table[table.n_pixels > min_size] + stat_values = spatial_statistics(table) + + keep_mask = np.array(stat_values > threshold).squeeze() + keep_ids = table.label_id.values[keep_mask] shape = segmentation.shape block_shape = (128, 128, 128) @@ -100,7 +177,6 @@ def filter_isolated_objects( blocking = nt.blocking([0] * len(shape), shape, block_shape) output = open_file(output_path, mode="a") - output_dataset = output.create_dataset( output_key, shape=shape, dtype=segmentation.dtype, chunks=chunks, compression="gzip" @@ -112,17 +188,16 @@ def filter_chunk(block_id): block = blocking.getBlock(block_id) volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) data = segmentation[volume_index] - data[np.isin(data, filter_ids)] = 0 + data[np.isin(data, keep_ids)] = 0 output_dataset[volume_index] = data # Limit the number of cores for parallelization. n_threads = min(16, mp.cpu_count()) - with futures.ThreadPoolExecutor(n_threads) as filter_pool: list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks)) seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive( - output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128) + output_dataset, start_label=1, keep_zeros=True, block_shape=block_shape ) - return seg_filtered, n_ids, n_ids_filtered + return n_ids, n_ids_filtered diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 3cfb1b0..a5021c6 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -25,7 +25,7 @@ from tqdm import tqdm import flamingo_tools.s3_utils as s3_utils -from flamingl_tools.file_utils import read_image_data +from flamingo_tools.file_utils import read_image_data class SelectChannel(SimpleTransformationWrapper): @@ -35,7 +35,7 @@ class SelectChannel(SimpleTransformationWrapper): volume: The array-like input dataset. channel: The channel that will be selected. 
""" - def __init__(self, volume: np.array_like, channel: int): + def __init__(self, volume: np.typing.ArrayLike, channel: int): self.channel = channel super().__init__(volume, lambda x: x[self.channel], with_channels=True) diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 04c05da..b3f9eb6 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -1,16 +1,15 @@ import argparse import os -import sys import pandas as pd import zarr -sys.path.append("../..") - import flamingo_tools.s3_utils as s3_utils +from flamingo_tools.segmentation import filter_segmentation + +# TODO needs updates def main(): - from flamingo_tools.segmentation import filter_isolated_objects parser = argparse.ArgumentParser( description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.") From 2ec1a8315d9bf78bce1d69f152b774c9d7a5e0b4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 21 Apr 2025 18:36:29 +0200 Subject: [PATCH 3/7] Fix file mode in tifffile --- flamingo_tools/file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index 2d2ac04..b861fc0 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -55,7 +55,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: The memory-mapped data. If not possible to memmap, the data in memory. """ try: - x = tifffile.memmap(file_path, "r") + x = tifffile.memmap(file_path) except ValueError: warnings.warn(f"Cannot memmap the tif file at {file_path}. Fall back to loading it into memory.") x = imageio.imread(file_path) From 4e64195357a7cc30f4cf40d4def5aba1ec353107 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Tue, 22 Apr 2025 14:55:11 +0200 Subject: [PATCH 4/7] Fixed default service endpoint --- flamingo_tools/s3_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index 92c4e4f..df52299 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -10,7 +10,7 @@ # Dedicated bucket for cochlea lightsheet project MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" -SERVICE_ENDPOINT = "https://s3.gwdg.de/" +SERVICE_ENDPOINT = "https://s3.fs.gwdg.de/" BUCKET_NAME = "cochlea-lightsheet" DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials") From 3f6c95ad850899ef84226dcc31e8748c4c224e83 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 23 Apr 2025 17:44:11 +0200 Subject: [PATCH 5/7] Linter, unify docstring, postprocessing --- flamingo_tools/file_utils.py | 6 +- flamingo_tools/s3_utils.py | 71 ++++++++---- flamingo_tools/segmentation/postprocessing.py | 10 +- .../segmentation/unet_prediction.py | 77 ++++++++----- scripts/convert_tif_to_n5.py | 23 ++-- scripts/extract_block.py | 101 ++++++++++-------- scripts/prediction/count_cells.py | 32 ++++-- scripts/prediction/expand_seg_table.py | 88 +++++++++++---- scripts/prediction/postprocess_seg.py | 84 +++++++++++---- scripts/resize_wrongly_scaled_cochleas.py | 21 +++- 10 files changed, 355 insertions(+), 158 deletions(-) diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index b861fc0..f121923 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Optional, Union +from typing import Optional, Union import 
imageio.v3 as imageio import numpy as np @@ -62,9 +62,7 @@ def read_tif(file_path: str) -> Union[np.ndarray, np.memmap]: return x -# TODO: Update the any types: -# The first should be the type of a zarr s3 store, -def read_image_data(input_path: Union[str, Any], input_key: Optional[str]) -> np.typing.ArrayLike: +def read_image_data(input_path: Union[str, zarr.storage.FSStore], input_key: Optional[str]) -> np.typing.ArrayLike: """Read flamingo image data, stored in various formats. Args: diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index df52299..4dc6a23 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -19,11 +19,17 @@ def check_s3_credentials( bucket_name: Optional[str], service_endpoint: Optional[str], credential_file: Optional[str] ) -> Tuple[str, str, str]: - """Check if S3 parameter and credentials were set either as input variables or as environment variables. + """Check if S3 parameter and credentials were set as input arguments, as environment variables, or as globals. Args: + bucket_name: S3 bucket name + service_endpoint: S3 service endpoint + credential_file: Credential file containing access key and secret key Returns: + bucket_name + service_endpoint + credential_file """ if bucket_name is None: bucket_name = os.getenv("BUCKET_NAME") @@ -87,49 +93,78 @@ def get_s3_path( bucket_name: Optional[str] = None, service_endpoint: Optional[str] = None, credential_file: Optional[str] = None, - # ) -> Tuple[]: -): +) -> Tuple[zarr.storage.FSStore, s3fs.core.S3FileSystem]: """Get S3 path for a file or folder and file system based on S3 parameters and credentials. + + Args: + input_path: Inputh path in S3 bucket + bucket_name: S3 bucket name + service_endpoint: S3 service endpoint + credential_file: Credential file containing access key and secret key + + Returns: + s3_path + s3_filesystem """ bucket_name, service_endpoint, credential_file = check_s3_credentials( bucket_name, service_endpoint, credential_file ) - fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) + s3_filesystem = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) zarr_path = f"{bucket_name}/{input_path}" - if not fs.exists(zarr_path): + if not s3_filesystem.exists(zarr_path): print(f"Error: S3 path {zarr_path} does not exist!") - s3_path = zarr.storage.FSStore(zarr_path, fs=fs) + s3_path = zarr.storage.FSStore(zarr_path, fs=s3_filesystem) - return s3_path, fs + return s3_path, s3_filesystem -def read_s3_credentials(credential_file): - key, secret = None, None +def read_s3_credentials(credential_file: str) -> Tuple[str, str]: + """Read access key amd secret key from credential file. 
+ + Args: + credential_file: File path to credentials + + Returns: + access_key + secret_key + """ + access_key, secret_key = None, None with open(credential_file) as f: for line in f: if line.startswith("aws_access_key_id"): - key = line.rstrip("\n").strip().split(" ")[-1] + access_key = line.rstrip("\n").strip().split(" ")[-1] if line.startswith("aws_secret_access_key"): - secret = line.rstrip("\n").strip().split(" ")[-1] - if key is None or secret is None: + secret_key = line.rstrip("\n").strip().split(" ")[-1] + if access_key is None or secret_key is None: raise ValueError(f"Invalid credential file {credential_file}") - return key, secret + return access_key, secret_key -def create_s3_target(url=None, anon=False, credential_file=None): +def create_s3_target( + url: Optional[str] = None, + anon: Optional[str] = False, + credential_file: Optional[str] = None, +) -> s3fs.core.S3FileSystem: """Create file system for S3 bucket based on a service endpoint and an optional credential file. - If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. + + Args: + url: Service endpoint for S3 bucket + anon: Option for anon argument of S3FileSystem + credential_file: File path to credentials + + Returns: + s3_filesystem """ client_kwargs = {"endpoint_url": SERVICE_ENDPOINT if url is None else url} if credential_file is not None: key, secret = read_s3_credentials(credential_file) - fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs) + s3_filesystem = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs) else: - fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs) - return fs + s3_filesystem = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs) + return s3_filesystem diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 72d019e..9c5ca97 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -113,6 +113,8 @@ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray: # Filter the segmentation based on a spatial statistics from above. # +# FIXME: functions causes ValueError by using arrays of different lengths + def _compute_table(segmentation): segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True) @@ -138,6 +140,7 @@ def filter_segmentation( min_size: int = 1000, table: Optional[pd.DataFrame] = None, output_key: str = "segmentation_postprocessed", + **spatial_statistics_kwargs, ) -> Tuple[int, int]: """Postprocessing step to filter isolated objects from a segmentation. @@ -147,11 +150,12 @@ def filter_segmentation( Args: segmentation: Dataset containing the segmentation output_path: Output path for postprocessed segmentation - spatial_statistics: + spatial_statistics: Function to calculate density measure for elements of segmentation threshold: Distance in micrometer to check for neighbors min_size: Minimal number of pixels for filtering small instances - table: + table: Dataframe of segmentation table output_key: Output key for postprocessed segmentation + spatial_statistics_kwargs: Arguments for spatial statistics function Returns: n_ids @@ -165,7 +169,7 @@ def filter_segmentation( # First apply the size filter. 
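Referring back to the s3_utils changes above: a minimal usage sketch of the refactored helpers. The dataset path and the internal key below are placeholders, not values taken from the patch.

import zarr

import flamingo_tools.s3_utils as s3_utils

# Resolve an FSStore and the underlying s3fs filesystem for data in the bucket.
# Bucket name, service endpoint, and credentials fall back to environment
# variables or the module-level defaults when not passed explicitly.
s3_store, fs = s3_utils.get_s3_path("some_cochlea/images/ome-zarr/stain.ome.zarr")

# Open the store read-only and access one scale level (placeholder key).
with zarr.open(s3_store, mode="r") as f:
    data = f["s0"]
    print(data.shape, data.dtype)
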
table = table[table.n_pixels > min_size] - stat_values = spatial_statistics(table) + stat_values = spatial_statistics(table, **spatial_statistics_kwargs) keep_mask = np.array(stat_values > threshold).squeeze() keep_ids = table.label_id.values[keep_mask] diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index a5021c6..05b3640 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -341,13 +341,27 @@ def run_unet_prediction( def run_unet_prediction_preprocess_slurm( - input_path, input_key, output_folder, - s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, -): + input_path: str, + input_key: Optional[str], + output_folder: str, + s3: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + s3_credentials: Optional[str] = None, +) -> None: """Pre-processing for the parallel prediction with U-Net models. Masks are stored in mask.zarr in the output folder. The mean and standard deviation are precomputed for later usage during prediction and stored in a JSON file within the output folder as mean_std.json. + + Args: + input_path: The path to the input data. + input_key: The key / internal path of the image data. + output_folder: The output folder for storing the segmentation related data. + s3: Flag for considering input_path fo S3 bucket. + s3_bucket_name: S3 bucket name. + s3_service_endpoint: S3 service endpoint. + s3_credentials: File path to credentials for S3 bucket. """ if s3 is not None: input_path, fs = s3_utils.get_s3_path( @@ -361,26 +375,35 @@ def run_unet_prediction_preprocess_slurm( def run_unet_prediction_slurm( - input_path, input_key, output_folder, model_path, - scale=None, - block_shape=None, halo=None, prediction_instances=1, - s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, -): - """ - Run prediction of distance U-Net for data stored locally or on an S3 bucket. - - :param str input_path: File path to input data - :param str input_key: Input key for data in ome.zarr format - :param str output_folder: Output folder for prediction.zarr - :param str model_path: File path to distance U-Net model - :param float scale: - :param tuple block_shape: - :param tuple halo: - :param int prediction_instances: Number of workers for parallel computation within slurm array - :param bool s3: Flag for accessing data on S3 bucket - :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported - :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported - :param str s3_credentials: Path to file containing S3 credentials + input_path: str, + input_key: Optional[str], + output_folder: str, + model_path: str, + scale: Optional[float] = None, + block_shape: Optional[Tuple[int, int, int]] = None, + halo: Optional[Tuple[int, int, int]] = None, + prediction_instances: Optional[int] = 1, + s3: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + s3_credentials: Optional[str] = None, +) -> None: + """Run prediction of distance U-Net for data stored locally or on an S3 bucket. + + Args: + input_path: The path to the input data. + input_key: The key / internal path of the image data. + output_folder: The output folder for storing the segmentation related data. + model_path: The path to the model to use for segmentation. 
+ scale: A factor to rescale the data before prediction. + By default the data will not be rescaled. + block_shape: The block-shape for running the prediction. + halo: The halo (= block overlap) to use for prediction. + prediction_instances: Number of instances for parallel prediction. + s3: Flag for considering input_path fo S3 bucket. + s3_bucket_name: S3 bucket name. + s3_service_endpoint: S3 service endpoint. + s3_credentials: File path to credentials for S3 bucket. """ os.makedirs(output_folder, exist_ok=True) prediction_instances = int(prediction_instances) @@ -417,7 +440,13 @@ def run_unet_prediction_slurm( # does NOT need GPU, FIXME: only run on CPU -def run_unet_segmentation_slurm(output_folder, min_size): +def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None: + """Create segmentation from prediction. + + Args: + output_folder: The output folder for storing the segmentation related data. + min_size: The minimal size of segmented objects in the output. + """ min_size = int(min_size) pmap_out = os.path.join(output_folder, "predictions.zarr") segmentation_impl(pmap_out, output_folder, min_size=min_size) diff --git a/scripts/convert_tif_to_n5.py b/scripts/convert_tif_to_n5.py index b8844f6..395f7d0 100644 --- a/scripts/convert_tif_to_n5.py +++ b/scripts/convert_tif_to_n5.py @@ -1,16 +1,17 @@ -import os, sys import argparse +import os import pybdv +import sys import imageio.v3 as imageio -def main(input_path, output_path): - """ - Convert tif file to n5 format. +def main(input_path: str, output_path: str = None) -> None: + """Convert tif file to n5 format. If no output_path is supplied, the output file is created in the same directory as the input. - :param str input_path: Input tif - :param str output_path: Output path for n5 format + Args: + input_path: Input path to tif. + output_path: Output path for n5 format. """ if not os.path.isfile(input_path): sys.exit("Input file does not exist.") @@ -22,19 +23,21 @@ def main(input_path, output_path): input_dir = input_path.split(basename)[0] input_dir = os.path.abspath(input_dir) - if "" == output_path: + if output_path is None: output_path = os.path.join(input_dir, basename + ".n5") + img = imageio.imread(input_path) pybdv.make_bdv(img, output_path) + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Script to transform file from tif into n5 format.") - parser.add_argument('input', type=str, help="Input file") - parser.add_argument('-o', "--output", type=str, default="", help="Output file. Default: .n5") + parser.add_argument('-i', '--input', required=True, type=str, help="Input file") + parser.add_argument('-o', "--output", type=str, default=None, help="Output file. Default: .n5") args = parser.parse_args() - main(args.input, args.output) \ No newline at end of file + main(args.input, args.output) diff --git a/scripts/extract_block.py b/scripts/extract_block.py index 3f5ab3a..688f794 100644 --- a/scripts/extract_block.py +++ b/scripts/extract_block.py @@ -1,57 +1,56 @@ -import os +"""This script extracts data around an input center coordinate in a given ROI halo. +""" import argparse +import json +import os +from typing import Optional, List + import numpy as np import zarr import flamingo_tools.s3_utils as s3_utils -""" -This script extracts data around an input center coordinate in a given ROI halo. - -The support for using an S3 bucket is currently limited to the lightsheet-cochlea bucket with the endpoint url https://s3.fs.gwdg.de. -If more use cases appear, the script will be generalized. 
-The usage requires the export of the access and the secret access key within the environment before executing the script. -run the following commands in the shell of your choice, or add them to your ~/.bashrc: -export AWS_ACCESS_KEY_ID= -export AWS_SECRET_ACCESS_KEY= -""" - def main( - input_file, output_dir, coords, input_key, resolution, roi_halo, - s3, s3_credentials, s3_bucket_name, s3_service_endpoint, - ): - """ - - :param str input_file: File path to input folder in n5 format - :param str output_dir: output directory for saving cropped n5 file as _crop.n5 - :param str input_key: Key for accessing volume in n5 format, e.g. 'setup0/s0' - :param float resolution: Resolution of input data in micrometer - :param str coords: Center coordinates of extracted 3D volume in format 'x,y,z' - :param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z' - :param bool s3: Flag for using an S3 bucket - :param str s3_credentials: Path to file containing S3 credentials - :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported - :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported + input_path: str, + coords: List[int], + output_dir: str = None, + input_key: str = "setup0/timepoint0/s0", + resolution: float = 0.38, + roi_halo: List[int] = [128, 128, 64], + s3: Optional[bool] = False, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, +): + """Extract block around coordinate from input data according to a given halo. + Either from a local file or from an S3 bucket. + + Args: + input_path: Input folder in n5 / ome-zarr format. + coords: Center coordinates of extracted 3D volume. + output_dir: Output directory for saving output as _crop.n5. Default: input directory. + roi_halo: ROI halo of extracted 3D volume. + s3: Flag for considering input_path for S3 bucket. + s3_bucket_name: S3 bucket name. + s3_service_endpoint: S3 service endpoint. + s3_credentials: File path to credentials for S3 bucket. 
""" - coords = [int(r) for r in coords.split(",")] - roi_halo = [int(r) for r in roi_halo.split(",")] - coord_string = "-".join([str(c) for c in coords]) # Dimensions are inversed to view in MoBIE (x y z) -> (z y x) coords.reverse() roi_halo.reverse() - input_content = list(filter(None, input_file.split("/"))) + input_content = list(filter(None, input_path.split("/"))) if s3: basename = input_content[0] + "_" + input_content[-1].split(".")[0] else: basename = "".join(input_content[-1].split(".")[:-1]) - input_dir = input_file.split(basename)[0] + input_dir = input_path.split(basename)[0] input_dir = os.path.abspath(input_dir) if output_dir == "": @@ -66,40 +65,52 @@ def main( roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) if s3: - s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + s3_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) with zarr.open(s3_path, mode="r") as f: raw = f[input_key][roi] else: - with zarr.open(input_file, mode="r") as f: + with zarr.open(input_path, mode="r") as f: raw = f[input_key][roi] with zarr.open(output_file, mode="w") as f_out: f_out.create_dataset("raw", data=raw, compression="gzip") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Script to extract region of interest (ROI) block around center coordinate.") - parser.add_argument('input', type=str, help="Input file in n5 format.") - parser.add_argument('-o', "--output", type=str, default="", help="Output directory") - parser.add_argument('-c', "--coord", type=str, required=True, help="3D coordinate in format 'x,y,z' as center of extracted block.") + parser.add_argument('-i', '--input', type=str, help="Input file in n5 / ome-zarr format.") + parser.add_argument('-o', "--output", type=str, default="", help="Output directory.") + parser.add_argument('-c', "--coord", type=str, required=True, + help="3D coordinate as center of extracted block, json-encoded.") - parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", help="Input key for data in input file") - parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer") + parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0", + help="Input key for data in input file.") + parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer.") - parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate in format 'x,y,z'") + parser.add_argument("--roi_halo", type=str, default="[128,128,64]", + help="ROI halo around center coordinate, json-encoded.") - parser.add_argument("--s3", action="store_true", help="Use S3 bucket") - parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials") - parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name") - parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint") + parser.add_argument("--s3", action="store_true", help="Use S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. 
Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") args = parser.parse_args() + coords = json.loads(args.coord) + roi_halo = json.loads(args.roi_halo) + main( - args.input, args.output, args.coord, args.input_key, args.resolution, args.roi_halo, + args.input, coords, args.output, args.input_key, args.resolution, roi_halo, args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, ) diff --git a/scripts/prediction/count_cells.py b/scripts/prediction/count_cells.py index 087dd79..34ce5af 100644 --- a/scripts/prediction/count_cells.py +++ b/scripts/prediction/count_cells.py @@ -8,17 +8,25 @@ import flamingo_tools.s3_utils as s3_utils + def main(): parser = argparse.ArgumentParser() - parser.add_argument("-o", "--output_folder", type=str, default=None, help="Output directory containing segmentation.zarr") - - parser.add_argument('-k', "--input_key", type=str, default="segmentation", help="Input key for data in input file") - parser.add_argument("-m", "--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object") - - parser.add_argument("--s3_input", default=None, help="Input file path on S3 bucket") - parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials") - parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name") - parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint") + parser.add_argument("-o", "--output_folder", type=str, default=None, + help="Output directory containing segmentation.zarr.") + + parser.add_argument('-k', "--input_key", type=str, default="segmentation", + help="The key / internal path of the segmentation.") + parser.add_argument("-m", "--min_size", type=int, default=1000, + help="Minimal number of voxel size for counting object.") + + parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") args = parser.parse_args() @@ -28,7 +36,10 @@ def main(): raise ValueError("Either provide an output_folder containing 'segmentation.zarr' or an S3 input.") if args.s3_input is not None: - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + with zarr.open(s3_path, mode="r") as f: dataset = f[args.input_key] @@ -44,5 +55,6 @@ def main(): counts = counts[counts > min_size] print("Number of objects:", len(counts)) + if __name__ == "__main__": main() diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py index dc080fd..1b6d3f7 100644 --- a/scripts/prediction/expand_seg_table.py +++ b/scripts/prediction/expand_seg_table.py @@ -1,25 +1,40 @@ import argparse +import json +from typing import Optional, List import pandas as pd import flamingo_tools.segmentation.postprocessing as postprocessing import flamingo_tools.s3_utils as s3_utils + def main( - in_path, out_path, n_neighbors=None, - s3=False, s3_credentials=None, s3_bucket_name=None, s3_service_endpoint=None, - ): - """ + in_path: str, + out_path: str, + n_neighbors: Optional[List[int]] = None, + local_ripley_radius: Optional[List[int]] = None, + r_neighbors: Optional[List[int]] = None, + s3: Optional[bool] = False, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, +): + """Expand TSV table with additional parameters for postprocessing. - :param str input_file: Path to table in TSV format - :param str out_path: Path to save output - :param bool s3: Flag for using an S3 bucket - :param str s3_credentials: Path to file containing S3 credentials - :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported - :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported + Args: + in_path: Path to table in TSV format. + out_path: Path to save output. + n_neighbors: Value(s) for nearest neighbor distance. + local_ripley_radius: Value(s) for calculating local Ripley's K function. + r_neighbors: Value(s) for radii for calculating number of neighbors in range. + s3: Flag for considering in_path fo S3 bucket. + s3_bucket_name: S3 bucket name. + s3_service_endpoint: S3 service endpoint. + s3_credentials: File path to credentials for S3 bucket. 
""" if s3: - tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") else: @@ -27,33 +42,62 @@ def main( tsv_table = pd.read_csv(f, sep="\t") if n_neighbors is not None: - nn_list = [int(n) for n in n_neighbors.split(",")] - for n_neighbor in nn_list: + for n_neighbor in n_neighbors: if n_neighbor >= len(tsv_table): - raise ValueError(f"Number of neighbors: {n_neighbor} exceeds number of elements in dataframe: {len(tsv_table)}.") + raise ValueError(f"Number of neighbors {n_neighbor} exceeds elements in dataframe: {len(tsv_table)}.") + + distance_avg = postprocessing.nearest_neighbor_distance(table=tsv_table, n_neighbors=n_neighbor) + tsv_table['distance_nn'+str(n_neighbor)] = list(distance_avg) + + if local_ripley_radius is not None: + for lr_radius in local_ripley_radius: + local_k = postprocessing.local_ripleys_k(table=tsv_table, radius=lr_radius) + tsv_table['local_ripley_radius'+str(lr_radius)] = list(local_k) - _ = postprocessing.distance_nearest_neighbors(tsv_table=tsv_table, n_neighbors=n_neighbor, expand_table=True) + if r_neighbors is not None: + for r_neighbor in r_neighbors: + neighbor_counts = postprocessing.neighbors_in_radius(table=tsv_table, radius=r_neighbor) + neighbor_counts = list(neighbor_counts) + neighbor_counts = [n[0] for n in neighbor_counts] + tsv_table['neighbors_in_radius'+str(r_neighbor)] = neighbor_counts tsv_table.to_csv(out_path, sep="\t") + if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Script for expanding the segmentation table of MoBIE with additonal parameters. Either locally or on an S3 bucket.") + description="Script for expanding the segmentation table of MoBIE with additonal parameters. " + "Either locally or on an S3 bucket.") parser.add_argument("-i", "--input", required=True) parser.add_argument("-o", "--output", required=True) - parser.add_argument("--n_neighbors", default=None, help="Value(s) for number of nearest neighbors in format 'n1,n2,...,nx'. New columns contain the average distance to nearest neighbors.") + parser.add_argument("--n_neighbors", default=None, + help="Value(s) for calculating distance to 'n' nearest neighbors, json-encoded.") + + parser.add_argument("--local_ripley_radius", default=None, + help="Value(s) for radii for calculating local Ripley's K function, json-encoded.") - parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket") - parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials. Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported") - parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name. Optional if BUCKET_NAME was exported") - parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported") + parser.add_argument("--r_neighbors", default=None, + help="Value(s) for radii for calculating number of neighbors in range, json-encoded.") + + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. 
" + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") args = parser.parse_args() + n_neighbors = json.loads(args.n_neighbors) if args.n_neighbors is not None else None + local_ripley_radius = json.loads(args.local_ripley_radius) if args.local_ripley_radius is not None else None + r_neighbors = json.loads(args.r_neighbors) if args.r_neighbors is not None else None + main( - args.input, args.output, args.n_neighbors, + args.input, args.output, n_neighbors, local_ripley_radius, r_neighbors, args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, ) diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index b3f9eb6..4869361 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -6,6 +6,7 @@ import flamingo_tools.s3_utils as s3_utils from flamingo_tools.segmentation import filter_segmentation +from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius # TODO needs updates @@ -14,34 +15,78 @@ def main(): parser = argparse.ArgumentParser( description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.") - parser.add_argument("-o", "--output_folder", required=True) + parser.add_argument("-o", "--output_folder", type=str, required=True) - parser.add_argument("-t", "--tsv", default=None, help="TSV-file in MoBIE format which contains information about the segmentation") - parser.add_argument('-k', "--input_key", type=str, default="segmentation", help="Input key for data in input file") - parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", help="Output key for data in input file") + parser.add_argument("-t", "--tsv", type=str, default=None, + help="TSV-file in MoBIE format which contains information about segmentation.") + parser.add_argument('-k', "--input_key", type=str, default="segmentation", + help="The key / internal path of the segmentation.") + parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", + help="The key / internal path of the output.") - parser.add_argument("--s3_input", default=None, help="Input file path on S3 bucket") - parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials. Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported") - parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name. Optional if BUCKET_NAME was exported") - parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported") + parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") - parser.add_argument("--min_size", type=int, default=None, help="Minimal number of voxel size for counting object") - parser.add_argument("--distance_threshold", type=int, default=15, help="Distance in micrometer to check for neighbors") - parser.add_argument("--neighbor_threshold", type=int, default=5, help="Minimal number of neighbors for filtering") + parser.add_argument("--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object") + + parser.add_argument("--n_neighbors", type=int, default=None, + help="Value for calculating distance to 'n' nearest neighbors.") + + parser.add_argument("--local_ripley_radius", type=int, default=None, + help="Value for radius for calculating local Ripley's K function.") + + parser.add_argument("--r_neighbors", type=int, default=None, + help="Value for radius for calculating number of neighbors in range.") args = parser.parse_args() + postprocess_functions = [nearest_neighbor_distance, local_ripleys_k, neighbors_in_radius] + function_keywords = ["n_neighbors", "radius", "radius"] + postprocess_options = [args.n_neighbors, args.local_ripley_radius, args.r_neighbors] + default_thresholds = [15, 20, 20] + + def create_spatial_statistics_dict(functions, keyword, options, threshold): + spatial_statistics_dict = [] + for f, o, k, t in zip(functions, keyword, options, threshold): + dic = {"function": f, "keyword": k, "argument": o, "threshold": t} + spatial_statistics_dict.append(dic) + return spatial_statistics_dict + + spatial_statistics_dict = create_spatial_statistics_dict(postprocess_functions, postprocess_options, + function_keywords, default_thresholds) + + if sum(x["argument"] is not None for x in spatial_statistics_dict) == 0: + raise ValueError("Choose a postprocess function from 'n_neighbors, 'local_ripley_radius', or 'r_neighbors'.") + elif sum(x["argument"] is not None for x in spatial_statistics_dict) > 1: + raise ValueError("The script only supports a single postprocess function.") + else: + for d in spatial_statistics_dict: + if d["argument"] is not None: + spatial_statistics = d["function"] + spatial_statistics_kwargs = {d["keyword"]: d["argument"]} + threshold = d["threshold"] + seg_path = os.path.join(args.output_folder, "segmentation.zarr") - tsv_table=None + tsv_table = None if args.s3_input is not None: - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) with zarr.open(s3_path, mode="r") as f: segmentation = f[args.input_key] if args.tsv is not None: - tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) + tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") @@ -53,13 +98,14 @@ def main(): with open(args.tsv, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") - seg_filtered, n_pre, n_post = filter_isolated_objects( - segmentation, output_path=seg_path, tsv_table=tsv_table, min_size=args.min_size, - distance_threshold=args.distance_threshold, neighbor_threshold=args.neighbor_threshold, - 
output_key=args.output_key, - ) + n_pre, n_post = filter_segmentation(segmentation, output_path=seg_path, + spatial_statistics=spatial_statistics, + threshold=threshold, + min_size=args.min_size, table=tsv_table, + output_key=args.output_key, **spatial_statistics_kwargs) print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}") + if __name__ == "__main__": main() diff --git a/scripts/resize_wrongly_scaled_cochleas.py b/scripts/resize_wrongly_scaled_cochleas.py index 5de7064..174741d 100644 --- a/scripts/resize_wrongly_scaled_cochleas.py +++ b/scripts/resize_wrongly_scaled_cochleas.py @@ -12,7 +12,22 @@ from flamingo_tools.file_utils import read_tif -def main(input_path, output_folder, scale, input_key, interpolation_order): +def main( + input_path: str, + output_folder: str, + scale: float = 0.38, + input_key: str = "setup0/timepoint0/s0", + interpolation_order: int = 3 +): + """Resize wrongly scaled cochleas. + + Args: + input_path: Input path to tif file or n5 folder. + output_folder: Output folder for rescaled data in n5 format. + scale: Scale of output data. + input_key: The key / internal path of the image data. + interpolation_order: Interpolation order for resizing function. + """ if input_path.endswith(".tif"): input_ = read_tif(input_path) input_chunks = (128,) * 3 @@ -63,9 +78,9 @@ def copy_chunk(block_index): parser = argparse.ArgumentParser( description="Script for resizing microscoopy data in n5 format.") - parser.add_argument("input_file", type=str, help="Input file") + parser.add_argument("input_file", type=str, required=True, help="Input tif file or n5 folder.") parser.add_argument( - "output_folder", type=str, help="Output folder. Default resized output is _resized.n5" + "output_folder", type=str, help="Output folder. Default resized output is '_resized.n5'." ) parser.add_argument("-s", "--scale", type=float, default=0.38, help="Scale of input. Re-scaled to 1.") From d26d848323282a894f6d59f61073590820b3cc9e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 23 Apr 2025 21:44:44 +0200 Subject: [PATCH 6/7] Fix issue in table computation and add tests --- flamingo_tools/segmentation/postprocessing.py | 12 ++--- test/test_segmentation/test_postprocessing.py | 49 +++++++++++++++++++ .../test_segmentation/test_unet_prediction.py | 3 -- 3 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 test/test_segmentation/test_postprocessing.py diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 9c5ca97..fd77add 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -6,7 +6,6 @@ import numpy as np import nifty.tools as nt import pandas as pd -import vigra from elf.io import open_file from scipy.spatial import distance @@ -113,15 +112,14 @@ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray: # Filter the segmentation based on a spatial statistics from above. # -# FIXME: functions causes ValueError by using arrays of different lengths - +# FIXME: this computes the distance in pixels, but the MoBIE table contains it in physical units (=nm) +# This is inconsistent. 
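As a reference for the _compute_table rework in the hunk below, a standalone sketch of the per-object table as it looks at the end of the series, i.e. with centroids scaled to physical units. The toy segmentation, the resolution value, and the centroid column names are illustrative only.

import numpy as np
import pandas as pd
from skimage import measure

# Toy 3D instance segmentation with two objects (made up for illustration).
seg = np.zeros((32, 32, 32), dtype="uint32")
seg[2:6, 2:6, 2:6] = 1
seg[20:28, 20:28, 20:28] = 2
resolution = 0.38  # micrometer per voxel, mirroring the script defaults

props = measure.regionprops(seg)
label_ids = np.array([prop.label for prop in props])
# Scale centroids from voxel coordinates to physical units (micrometer).
coordinates = np.array([prop.centroid for prop in props]) * resolution
sizes = np.array([prop.area for prop in props])

# The centroid column names are placeholders; the project follows its MoBIE table layout.
table = pd.DataFrame({
    "label_id": label_ids,
    "n_pixels": sizes,
    "centroid_z": coordinates[:, 0],
    "centroid_y": coordinates[:, 1],
    "centroid_x": coordinates[:, 2],
})
print(table)
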
def _compute_table(segmentation): - segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True) props = measure.regionprops(segmentation) - coordinates = np.array([prop.centroid for prop in props])[1:] - label_ids = np.unique(segmentation)[1:] - sizes = np.array([prop.area for prop in props])[1:] + label_ids = np.array([prop.label for prop in props]) + coordinates = np.array([prop.centroid for prop in props]) + sizes = np.array([prop.area for prop in props]) table = pd.DataFrame({ "label_id": label_ids, "n_pixels": sizes, diff --git a/test/test_segmentation/test_postprocessing.py b/test/test_segmentation/test_postprocessing.py new file mode 100644 index 0000000..531e2d2 --- /dev/null +++ b/test/test_segmentation/test_postprocessing.py @@ -0,0 +1,49 @@ +import os +import tempfile +import unittest + +from elf.io import open_file +from skimage.data import binary_blobs +from skimage.measure import label + + +class TestPostprocessing(unittest.TestCase): + def _create_example_seg(self, tmp_dir): + seg = binary_blobs(256, n_dim=3, volume_fraction=0.2) + seg = label(seg) + return seg + + def _test_postprocessing(self, spatial_statistics, threshold, **spatial_statistics_kwargs): + from flamingo_tools.segmentation.postprocessing import filter_segmentation + + with tempfile.TemporaryDirectory() as tmp_dir: + example_seg = self._create_example_seg(tmp_dir) + output_path = os.path.join(tmp_dir, "test-output.zarr") + output_key = "seg-filtered" + filter_segmentation( + example_seg, output_path, spatial_statistics, threshold, + output_key=output_key, **spatial_statistics_kwargs + ) + self.assertTrue(os.path.exists(output_path)) + with open_file(output_path, "r") as f: + filtered_seg = f[output_key][:] + self.assertEqual(filtered_seg.shape, example_seg.shape) + + def test_nearest_neighbor_distance(self): + from flamingo_tools.segmentation.postprocessing import nearest_neighbor_distance + + self._test_postprocessing(nearest_neighbor_distance, threshold=5) + + def test_local_ripleys_k(self): + from flamingo_tools.segmentation.postprocessing import local_ripleys_k + + self._test_postprocessing(local_ripleys_k, threshold=0.5) + + def test_neighbors_in_radius(self): + from flamingo_tools.segmentation.postprocessing import neighbors_in_radius + + self._test_postprocessing(neighbors_in_radius, threshold=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_segmentation/test_unet_prediction.py b/test/test_segmentation/test_unet_prediction.py index 9ea1f5c..9038bcc 100644 --- a/test/test_segmentation/test_unet_prediction.py +++ b/test/test_segmentation/test_unet_prediction.py @@ -1,5 +1,4 @@ import os -import sys import tempfile import unittest @@ -9,8 +8,6 @@ import z5py from torch_em.model import UNet3d -sys.path.append("../..") - class TestUnetPrediction(unittest.TestCase): shape = (64, 128, 128) From 17bae381aa9e664206c71181683bc06bc276ba7a Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 24 Apr 2025 13:05:07 +0200 Subject: [PATCH 7/7] Fixed mismatch of pixel distances and physical units --- flamingo_tools/segmentation/postprocessing.py | 10 ++++++---- scripts/prediction/postprocess_seg.py | 3 +++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index fd77add..7ad987b 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -113,12 +113,12 @@ def 
neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray: # -# FIXME: this computes the distance in pixels, but the MoBIE table contains it in physical units (=nm) -# This is inconsistent. -def _compute_table(segmentation): +def _compute_table(segmentation, resolution): props = measure.regionprops(segmentation) label_ids = np.array([prop.label for prop in props]) coordinates = np.array([prop.centroid for prop in props]) + # transform pixel distance to physical units + coordinates = coordinates * resolution sizes = np.array([prop.area for prop in props]) table = pd.DataFrame({ "label_id": label_ids, @@ -137,6 +137,7 @@ def filter_segmentation( threshold: float, min_size: int = 1000, table: Optional[pd.DataFrame] = None, + resolution: float = 0.38, output_key: str = "segmentation_postprocessed", **spatial_statistics_kwargs, ) -> Tuple[int, int]: @@ -152,6 +153,7 @@ def filter_segmentation( threshold: Distance in micrometer to check for neighbors min_size: Minimal number of pixels for filtering small instances table: Dataframe of segmentation table + resolution: Resolution of segmentation in micrometer output_key: Output key for postprocessed segmentation spatial_statistics_kwargs: Arguments for spatial statistics function @@ -162,7 +164,7 @@ def filter_segmentation( # Compute the table on the fly. # NOTE: this currently doesn't work for large segmentations. if table is None: - table = _compute_table(segmentation) + table = _compute_table(segmentation, resolution=resolution) n_ids = len(table) # First apply the size filter. diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 4869361..0134539 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -23,6 +23,8 @@ def main(): help="The key / internal path of the segmentation.") parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", help="The key / internal path of the output.") + parser.add_argument('-r', "--resolution", type=float, default=0.38, + help="Resolution of segmentation in micrometer.") parser.add_argument("--s3_input", type=str, default=None, help="Input file path on S3 bucket.") parser.add_argument("--s3_credentials", type=str, default=None, @@ -102,6 +104,7 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold): spatial_statistics=spatial_statistics, threshold=threshold, min_size=args.min_size, table=tsv_table, + resolution=args.resolution, output_key=args.output_key, **spatial_statistics_kwargs) print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}")
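
To close out the series, a hedged end-to-end sketch of the postprocessing entry point as it looks after the last patch. The file paths, threshold, and n_neighbors value are placeholders; the call signature follows the hunks above.

import zarr

from flamingo_tools.segmentation.postprocessing import (
    filter_segmentation,
    nearest_neighbor_distance,
)

# Open the segmentation produced by the U-Net pipeline (placeholder path and key).
segmentation = zarr.open("output/segmentation.zarr", mode="r")["segmentation"]

# Filter instances with a minimum-size criterion plus the nearest-neighbor
# statistic, evaluated against the threshold in micrometer. The table is
# computed on the fly here, which per the NOTE above only works for
# segmentations that fit in memory.
n_pre, n_post = filter_segmentation(
    segmentation,
    output_path="output/segmentation.zarr",
    spatial_statistics=nearest_neighbor_distance,
    threshold=15,                # micrometer, mirrors the script default
    min_size=1000,
    resolution=0.38,             # voxel size in micrometer
    output_key="segmentation_postprocessed",
    n_neighbors=10,              # forwarded via **spatial_statistics_kwargs
)
print(f"Objects before filtering: {n_pre}, after filtering: {n_post}")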