diff --git a/environment.yaml b/environment.yaml index 05b23c6..fa924c9 100644 --- a/environment.yaml +++ b/environment.yaml @@ -12,6 +12,7 @@ dependencies: - pytorch - s3fs - torch_em + - trimesh - z5py # Don't install zarr v3, as we are not sure that it is compatible with MoBIE etc. yet - zarr <3 diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py new file mode 100644 index 0000000..442b4fe --- /dev/null +++ b/flamingo_tools/measurements.py @@ -0,0 +1,146 @@ +import multiprocessing as mp +from concurrent import futures +from typing import Optional + +import numpy as np +import pandas as pd +import trimesh +from skimage.measure import marching_cubes +from tqdm import tqdm + +from .file_utils import read_image_data +from .segmentation.postprocessing import compute_table_on_the_fly + + +def _measure_volume_and_surface(mask, resolution): + # Use marching_cubes for 3D data + verts, faces, normals, _ = marching_cubes(mask, spacing=(resolution,) * 3) + + mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals) + surface = mesh.area + if mesh.is_watertight: + volume = np.abs(mesh.volume) + else: + volume = np.nan + + return volume, surface + + +def compute_object_measures_impl( + image: np.typing.ArrayLike, + segmentation: np.typing.ArrayLike, + n_threads: Optional[int] = None, + resolution: float = 0.38, + table: Optional[pd.DataFrame] = None, +) -> pd.DataFrame: + """Compute simple intensity and morphology measures for each segmented cell in a segmentation. + + See `compute_object_measures` for details. + + Args: + image: The image data. + segmentation: The segmentation. + n_threads: The number of threads to use for computation. + resolution: The resolution / voxel size of the data. + table: The segmentation table. Will be computed on the fly if it is not given. + + Returns: + The table with per object measurements. + """ + if table is None: + table = compute_table_on_the_fly(segmentation, resolution=resolution) + + def intensity_measures(seg_id): + # Get the bounding box. + row = table[table.label_id == seg_id] + + bb_min = np.array([ + row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item() + ]).astype("float32") / resolution + bb_min = np.round(bb_min, 0).astype("int32") + + bb_max = np.array([ + row.bb_max_z.item(), row.bb_max_y.item(), row.bb_max_x.item() + ]).astype("float32") / resolution + bb_max = np.round(bb_max, 0).astype("int32") + + bb = tuple( + slice(max(bmin - 1, 0), min(bmax + 1, sh)) + for bmin, bmax, sh in zip(bb_min, bb_max, image.shape) + ) + + local_image = image[bb] + mask = segmentation[bb] == seg_id + assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty." + masked_intensity = local_image[mask] + + # Do the base intensity measurements. + measures = { + "label_id": seg_id, + "mean": np.mean(masked_intensity), + "stdev": np.std(masked_intensity), + "min": np.min(masked_intensity), + "max": np.max(masked_intensity), + "median": np.median(masked_intensity), + } + for percentile in (5, 10, 25, 75, 90, 95): + measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile) + + # Do the volume and surface measurement. + volume, surface = _measure_volume_and_surface(mask, resolution) + measures["volume"] = volume + measures["surface"] = surface + return measures + + seg_ids = table.label_id.values + assert len(seg_ids) > 0, "The segmentation table is empty." + n_threads = mp.cpu_count() if n_threads is None else n_threads + with futures.ThreadPoolExecutor(n_threads) as pool: + measures = list(tqdm( + pool.map(intensity_measures, seg_ids), total=len(seg_ids), desc="Compute intensity measures" + )) + + # Create the result table and save it. + keys = measures[0].keys() + measures = pd.DataFrame({k: [measure[k] for measure in measures] for k in keys}) + return measures + + +# Could also support s3 directly? +def compute_object_measures( + image_path: str, + segmentation_path: str, + segmentation_table_path: str, + output_table_path: str, + image_key: Optional[str] = None, + segmentation_key: Optional[str] = None, + n_threads: Optional[int] = None, + resolution: float = 0.38, +) -> None: + """Compute simple intensity and morphology measures for each segmented cell in a segmentation. + + This computes the mean, standard deviation, minimum, maximum, median and + 5th, 10th, 25th, 75th, 90th and 95th percentile of the intensity image + per cell, as well as the volume and surface. + + Args: + image_path: The filepath to the image data. Either a tif or hdf5/zarr/n5 file. + segmentation_path: The filepath to the segmentation data. Either a tif or hdf5/zarr/n5 file. + segmentation_table_path: The path to the segmentation table in MoBIE format. + output_table_path: The path for saving the segmentation with intensity measures. + image_key: The key (= internal path) for the image data. Not needed fir tif. + segmentation_key: The key (= internal path) for the segmentation data. Not needed for tif. + n_threads: The number of threads to use for computation. + resolution: The resolution / voxel size of the data. + """ + # First, we load the pre-computed segmentation table from MoBIE. + table = pd.read_csv(segmentation_table_path, sep="\t") + + # Then, open the volumes. + image = read_image_data(image_path, image_key) + segmentation = read_image_data(segmentation_path, segmentation_key) + + measures = compute_object_measures_impl( + image, segmentation, n_threads, resolution, table=table + ) + measures.to_csv(output_table_path, sep="\t", index=False) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 75529b6..3b24ce9 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -115,19 +115,39 @@ def neighbors_in_radius(table: pd.DataFrame, radius: float = 15) -> np.ndarray: # -def _compute_table(segmentation, resolution): +def compute_table_on_the_fly(segmentation: np.typing.ArrayLike, resolution: float) -> pd.DataFrame: + """Compute a segmentation table compatible with MoBIE. + + The table contains information about the number of pixels per object, + the anchor (= centroid) and the bounding box. Anchor and bounding box are given in physical coordinates. + + Args: + segmentation: The segmentation for which to compute the table. + resolution: The physical voxel spacing of the data. + + Returns: + The segmentation table. + """ props = measure.regionprops(segmentation) label_ids = np.array([prop.label for prop in props]) - coordinates = np.array([prop.centroid for prop in props]) + coordinates = np.array([prop.centroid for prop in props]).astype("float32") # transform pixel distance to physical units coordinates = coordinates * resolution + bb_min = np.array([prop.bbox[:3] for prop in props]).astype("float32") * resolution + bb_max = np.array([prop.bbox[3:] for prop in props]).astype("float32") * resolution sizes = np.array([prop.area for prop in props]) table = pd.DataFrame({ "label_id": label_ids, - "n_pixels": sizes, "anchor_x": coordinates[:, 2], "anchor_y": coordinates[:, 1], "anchor_z": coordinates[:, 0], + "bb_min_x": bb_min[:, 2], + "bb_min_y": bb_min[:, 1], + "bb_min_z": bb_min[:, 0], + "bb_max_x": bb_max[:, 2], + "bb_max_y": bb_max[:, 1], + "bb_max_z": bb_max[:, 0], + "n_pixels": sizes, }) return table @@ -160,13 +180,12 @@ def filter_segmentation( spatial_statistics_kwargs: Arguments for spatial statistics function Returns: - n_ids - n_ids_filtered + The number of objects before filtering. + The number of objects after filtering. """ - # Compute the table on the fly. - # NOTE: this currently doesn't work for large segmentations. + # Compute the table on the fly. This doesn't work for large segmentations. if table is None: - table = _compute_table(segmentation, resolution=resolution) + table = compute_table_on_the_fly(segmentation, resolution=resolution) n_ids = len(table) # First apply the size filter. diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index a1d1ba8..8b754d3 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -324,6 +324,7 @@ def run_unet_prediction( scale: Optional[float] = None, block_shape: Optional[Tuple[int, int, int]] = None, halo: Optional[Tuple[int, int, int]] = None, + use_mask: bool = True, ) -> None: """Run prediction and segmentation with a distance U-Net. @@ -337,10 +338,12 @@ def run_unet_prediction( By default the data will not be rescaled. block_shape: The block-shape for running the prediction. halo: The halo (= block overlap) to use for prediction. + use_mask: Whether to use the masking heuristics to not run inference on empty blocks. """ os.makedirs(output_folder, exist_ok=True) - find_mask(input_path, input_key, output_folder) + if use_mask: + find_mask(input_path, input_key, output_folder) original_shape = prediction_impl( input_path, input_key, output_folder, model_path, scale, block_shape, halo diff --git a/flamingo_tools/test_data.py b/flamingo_tools/test_data.py index bca605d..7450a3b 100644 --- a/flamingo_tools/test_data.py +++ b/flamingo_tools/test_data.py @@ -1,7 +1,75 @@ import os +from typing import Tuple import imageio.v3 as imageio -from skimage.data import binary_blobs +import requests +from skimage.data import binary_blobs, cells3d +from skimage.measure import label + +from .segmentation.postprocessing import compute_table_on_the_fly + +SEGMENTATION_URL = "https://owncloud.gwdg.de/index.php/s/kwoGRYiJRRrswgw/download" + + +def get_test_volume_and_segmentation(folder: str) -> Tuple[str, str, str]: + """Download a small volume with nuclei and corresponding segmentation. + + Args: + folder: The test data folder. The data will be downloaded to this folder. + + Returns: + The path to the image, stored as tif. + The path to the segmentation, stored as tif. + The path to the segmentation table, stored as tsv. + """ + os.makedirs(folder, exist_ok=True) + + segmentation_path = os.path.join(folder, "segmentation.tif") + resp = requests.get(SEGMENTATION_URL) + resp.raise_for_status() + + with open(segmentation_path, "wb") as f: + f.write(resp.content) + + nuclei = cells3d()[20:40, 1] + segmentation = imageio.imread(segmentation_path) + assert nuclei.shape == segmentation.shape + + image_path = os.path.join(folder, "image.tif") + imageio.imwrite(image_path, nuclei) + + table_path = os.path.join(folder, "default.tsv") + table = compute_table_on_the_fly(segmentation, resolution=0.38) + table.to_csv(table_path, sep="\t", index=False) + + return image_path, segmentation_path, table_path + + +def create_image_data_and_segmentation(folder: str, size: int = 256) -> Tuple[str, str, str]: + """Create test data containing an image, a corresponding segmentation and segmentation table. + + Args: + folder: The test data folder. The data will be written to this folder. + + Returns: + The path to the image, stored as tif. + The path to the segmentation, stored as tif. + The path to the segmentation table, stored as tsv. + """ + os.makedirs(folder, exist_ok=True) + data = binary_blobs(size, n_dim=3).astype("uint8") * 255 + seg = label(data) + + image_path = os.path.join(folder, "image.tif") + segmentation_path = os.path.join(folder, "segmentation.tif") + imageio.imwrite(image_path, data) + imageio.imwrite(segmentation_path, seg) + + table_path = os.path.join(folder, "default.tsv") + table = compute_table_on_the_fly(seg, resolution=0.38) + table.to_csv(table_path, sep="\t", index=False) + + return image_path, segmentation_path, table_path # TODO add metadata diff --git a/scripts/measurements/measure_sgns.py b/scripts/measurements/measure_sgns.py new file mode 100644 index 0000000..f94730a --- /dev/null +++ b/scripts/measurements/measure_sgns.py @@ -0,0 +1,48 @@ +import json +import os + +import numpy as np +import pandas as pd +from flamingo_tools.s3_utils import create_s3_target, BUCKET_NAME + + +def open_json(fs, path): + s3_path = os.path.join(BUCKET_NAME, path) + with fs.open(s3_path, "r") as f: + content = json.load(f) + return content + + +def open_tsv(fs, path): + s3_path = os.path.join(BUCKET_NAME, path) + with fs.open(s3_path, "r") as f: + table = pd.read_csv(f, sep="\t") + return table + + +def main(): + fs = create_s3_target() + project_info = open_json(fs, "project.json") + for dataset in project_info["datasets"]: + if dataset == "fens": + continue + print(dataset) + dataset_info = open_json(fs, os.path.join(dataset, "dataset.json")) + sources = dataset_info["sources"] + for source, source_info in sources.items(): + if not source.startswith("SGN"): + continue + assert "segmentation" in source_info + source_info = source_info["segmentation"] + table_path = source_info["tableData"]["tsv"]["relativePath"] + table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv")) + component_labels = table.component_labels.values + remaining_sgns = component_labels[component_labels != 0] + print(source) + print("Number of SGNs (all components) :", len(remaining_sgns)) + _, n_per_component = np.unique(remaining_sgns, return_counts=True) + print("Number of SGNs (largest component):", max(n_per_component)) + + +if __name__ == "__main__": + main() diff --git a/scripts/sgn_stain_predictions/check_segmentation.py b/scripts/sgn_stain_predictions/check_segmentation.py new file mode 100644 index 0000000..3a4799c --- /dev/null +++ b/scripts/sgn_stain_predictions/check_segmentation.py @@ -0,0 +1,35 @@ +import os +from glob import glob + +import imageio.v3 as imageio +import napari + + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" +SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" + + +def main(): + files = sorted(glob(os.path.join(ROOT, "**/*.tif"))) + for ff in files: + if "segmentations" in ff: + return + print("Visualizing", ff) + rel_path = os.path.relpath(ff, ROOT) + seg_path = os.path.join(SAVE_ROOT, rel_path) + + image = imageio.imread(ff) + if os.path.exists(seg_path): + seg = imageio.imread(seg_path) + else: + seg = None + + v = napari.Viewer() + v.add_image(image) + if seg is not None: + v.add_labels(seg) + napari.run() + + +if __name__ == "__main__": + main() diff --git a/scripts/sgn_stain_predictions/measure_intensities.py b/scripts/sgn_stain_predictions/measure_intensities.py new file mode 100644 index 0000000..99f67ac --- /dev/null +++ b/scripts/sgn_stain_predictions/measure_intensities.py @@ -0,0 +1,39 @@ +import os +from glob import glob + +import tifffile +from flamingo_tools.measurements import compute_object_measures_impl + + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" +SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" + + +def measure_intensities(ff): + rel_path = os.path.relpath(ff, ROOT) + out_path = os.path.join("./measurements", rel_path.replace(".tif", ".xlsx")) + if os.path.exists(out_path): + return + + print("Computing measurements for", rel_path) + seg_path = os.path.join(SAVE_ROOT, rel_path) + + image_data = tifffile.memmap(ff) + seg_data = tifffile.memmap(seg_path) + + table = compute_object_measures_impl(image_data, seg_data, n_threads=8) + + os.makedirs(os.path.split(out_path)[0], exist_ok=True) + table.to_excel(out_path, index=False) + + +def main(): + files = sorted(glob(os.path.join(ROOT, "**/*.tif"))) + for ff in files: + if "segmentations" in ff: + return + measure_intensities(ff) + + +if __name__ == "__main__": + main() diff --git a/scripts/sgn_stain_predictions/run_prediction.py b/scripts/sgn_stain_predictions/run_prediction.py new file mode 100644 index 0000000..1badae3 --- /dev/null +++ b/scripts/sgn_stain_predictions/run_prediction.py @@ -0,0 +1,58 @@ +import os +import tempfile +from glob import glob + +import tifffile +from elf.io import open_file +from flamingo_tools.segmentation import run_unet_prediction + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" +MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa + +SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" + + +def check_data(): + files = glob(os.path.join(ROOT, "**/*.tif"), recursive=True) + for ff in files: + rel_path = sorted(os.path.relpath(ff, ROOT)) + shape = tifffile.memmap(ff).shape + print(rel_path, shape) + + +def segment_crop(input_file): + fname = os.path.relpath(input_file, ROOT) + out_file = os.path.join(SAVE_ROOT, fname) + if "segmentations" in input_file: + return + if os.path.exists(out_file): + return + + print("Run prediction for", input_file) + os.makedirs(os.path.split(out_file)[0], exist_ok=True) + with tempfile.TemporaryDirectory() as tmp_folder: + run_unet_prediction( + input_file, input_key=None, output_folder=tmp_folder, + model_path=MODEL_PATH, min_size=1000, use_mask=False, + ) + seg_path = os.path.join(tmp_folder, "segmentation.zarr") + with open_file(seg_path, mode="r") as f: + seg = f["segmentation"][:] + + print("Writing output to", out_file) + tifffile.imwrite(out_file, seg, bigtiff=True) + + +def segment_all(): + files = sorted(glob(os.path.join(ROOT, "**/*.tif"), recursive=True)) + for ff in files: + segment_crop(ff) + + +def main(): + # check_data() + segment_all() + + +if __name__ == "__main__": + main() diff --git a/test/test_measurements.py b/test/test_measurements.py new file mode 100644 index 0000000..7ab67e3 --- /dev/null +++ b/test/test_measurements.py @@ -0,0 +1,60 @@ +import os +import unittest +from shutil import rmtree + +import imageio.v3 as imageio +import pandas as pd +import numpy as np +from skimage.measure import regionprops_table + + +class TestMeasurements(unittest.TestCase): + folder = "./tmp" + + def setUp(self): + from flamingo_tools.test_data import get_test_volume_and_segmentation + + self.image_path, self.seg_path, self.table_path = get_test_volume_and_segmentation(self.folder) + + def tearDown(self): + try: + rmtree(self.folder) + except Exception: + pass + + def test_compute_object_measures(self): + from flamingo_tools.measurements import compute_object_measures + + output_path = os.path.join(self.folder, "measurements.tsv") + compute_object_measures( + self.image_path, self.seg_path, self.table_path, output_path, n_threads=1 + ) + self.assertTrue(os.path.exists(output_path)) + + table = pd.read_csv(output_path, sep="\t") + self.assertTrue(len(table) >= 1) + expected_columns = ["label_id", "mean", "stdev", "min", "max", "median"] + expected_columns.extend([f"percentile-{p}" for p in (5, 10, 25, 75, 90, 95)]) + expected_columns.extend(["volume", "surface"]) + for col in expected_columns: + self.assertIn(col, table.columns) + + n_objects = int(imageio.imread(self.seg_path).max()) + expected_shape = (n_objects, len(expected_columns)) + self.assertEqual(table.shape, expected_shape) + + image = imageio.imread(self.image_path) + segmentation = imageio.imread(self.seg_path) + properties = ("label", "intensity_mean", "intensity_std", "intensity_min", "intensity_max") + expected_measures = regionprops_table(segmentation, intensity_image=image, properties=properties) + expected_measures = pd.DataFrame(expected_measures) + + for (col, col_exp) in [ + ("label_id", "label"), ("mean", "intensity_mean"), ("stdev", "intensity_std"), + ("min", "intensity_min"), ("max", "intensity_max"), + ]: + self.assertTrue(np.allclose(table[col].values, expected_measures[col_exp].values)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_segmentation/test_postprocessing.py b/test/test_segmentation/test_postprocessing.py index 531e2d2..c8be572 100644 --- a/test/test_segmentation/test_postprocessing.py +++ b/test/test_segmentation/test_postprocessing.py @@ -2,9 +2,12 @@ import tempfile import unittest +import imageio.v3 as imageio +import numpy as np +import pandas as pd from elf.io import open_file from skimage.data import binary_blobs -from skimage.measure import label +from skimage.measure import label, regionprops_table class TestPostprocessing(unittest.TestCase): @@ -44,6 +47,32 @@ def test_neighbors_in_radius(self): self._test_postprocessing(neighbors_in_radius, threshold=5) + def test_compute_table_on_the_fly(self): + from flamingo_tools.segmentation.postprocessing import compute_table_on_the_fly + from flamingo_tools.test_data import get_test_volume_and_segmentation + + with tempfile.TemporaryDirectory() as tmp_dir: + _, seg_path, _ = get_test_volume_and_segmentation(tmp_dir) + segmentation = imageio.imread(seg_path) + + resolution = 0.38 + table = compute_table_on_the_fly(segmentation, resolution=resolution) + + properties = ("label", "bbox", "centroid") + expected_table = regionprops_table(segmentation, properties=properties) + expected_table = pd.DataFrame(expected_table) + + for (col, col_exp) in [ + ("label_id", "label"), + ("anchor_x", "centroid-2"), ("anchor_y", "centroid-1"), ("anchor_z", "centroid-0"), + ("bb_min_x", "bbox-2"), ("bb_min_y", "bbox-1"), ("bb_min_z", "bbox-0"), + ("bb_max_x", "bbox-5"), ("bb_max_y", "bbox-4"), ("bb_max_z", "bbox-3"), + ]: + values = table[col].values + if col != "label_id": + values /= resolution + self.assertTrue(np.allclose(values, expected_table[col_exp].values)) + if __name__ == "__main__": unittest.main()