Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions flamingo_tools/segmentation/unet_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import warnings
from concurrent import futures
from functools import partial
from typing import Optional, Tuple

import elf.parallel as parallel
Expand All @@ -17,10 +18,11 @@
import torch
import z5py

from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper, SimpleTransformationWrapperWithHalo
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from elf.io import open_file
from skimage.filters import gaussian
from torch_em.util import load_model
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm
Expand Down Expand Up @@ -278,6 +280,7 @@ def distance_watershed_implementation(
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
distance_smoothing: float = 0.0,
original_shape: Optional[Tuple[int, int, int]] = None
) -> None:
"""Parallel implementation of the distance-prediction based watershed.
Expand All @@ -290,6 +293,8 @@ def distance_watershed_implementation(
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
original_shape: The original shape to resize the segmentation to.
"""
if isinstance(input_path, str):
Expand All @@ -307,11 +312,14 @@ def distance_watershed_implementation(
center_distances = SelectChannel(input_, 1)
boundary_distances = SelectChannel(input_, 2)

# Apply (lazy) smoothing to both.
# NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now.
# smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
# Apply (lazy) smoothing to both channels if distance smoothing was set.
if distance_smoothing > 0:
smooth = partial(gaussian, sigma=distance_smoothing)
# We assume that the gaussian is truncated at 5.3 sigma (tolerance of 1e-6)
halo = int(np.ceil(5.3 * distance_smoothing))
halo = 3 * (halo,)
center_distances = SimpleTransformationWrapperWithHalo(center_distances, transformation=smooth, halo=halo)
boundary_distances = SimpleTransformationWrapperWithHalo(boundary_distances, transformation=smooth, halo=halo)

# Allocate the (zarr) array for the seeds.
if output_folder is None:
Expand Down Expand Up @@ -427,6 +435,7 @@ def run_unet_prediction(
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
distance_smoothing: float = 0.0,
seg_class: Optional[str] = None,
) -> None:
"""Run prediction and segmentation with a distance U-Net.
Expand All @@ -446,6 +455,8 @@ def run_unet_prediction(
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
seg_class: Specifier for exclusion criterias for mask generation.
"""
if output_folder is not None:
Expand All @@ -470,7 +481,8 @@ def run_unet_prediction(
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
fg_threshold=fg_threshold
fg_threshold=fg_threshold,
distance_smoothing=distance_smoothing,
)

return segmentation
Expand Down Expand Up @@ -590,6 +602,7 @@ def run_unet_segmentation_slurm(
center_distance_threshold: float = 0.4,
boundary_distance_threshold: float = 0.5,
fg_threshold: float = 0.5,
distance_smoothing: float = 0.0,
) -> None:
"""Create segmentation from prediction.

Expand All @@ -600,10 +613,13 @@ def run_unet_segmentation_slurm(
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
"""
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
fg_threshold=fg_threshold,
distance_smoothing=distance_smoothing,
min_size=min_size)
108 changes: 108 additions & 0 deletions scripts/validation/IHCs/grid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
from glob import glob

import imageio.v3 as imageio
import numpy as np
from skimage.measure import label
from skimage.segmentation import relabel_sequential

from torch_em.util import load_model
from torch_em.util.grid_search import DistanceBasedInstanceSegmentation, instance_segmentation_grid_search

# Root folder that is searched recursively for the images ("*lut3*.tif") and
# the annotations ("*_corrected.tif" / "*_gt.tif") used by this script.
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-07-for-grid-search"
# Path handed to `load_model` to load the network evaluated in the grid search.
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14"  # noqa

# Candidate values per segmentation parameter. The dict is passed as
# `grid_search_values` to the grid search, and its keys are also used as the
# column names when the result tables are evaluated.
GRID_SEARCH_VALUES = {
    "center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
    "boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
    "distance_smoothing": [0.0, 0.6, 1.0, 1.6],
    "min_size": [0],
}


def preprocess_gt():
    """Derive ground-truth label images from the corrected annotations.

    Every "*_corrected.tif" file found recursively below ROOT is relabeled
    with `skimage.measure.label`, labels covering fewer than 750 elements are
    set to 0, the remaining labels are made sequential, and the result is
    written next to the input with "_corrected" replaced by "_gt" in the path.
    """
    min_size = 750
    pattern = os.path.join(ROOT, "**/*_corrected.tif")

    for corrected_path in sorted(glob(pattern, recursive=True)):
        components = label(imageio.imread(corrected_path))

        # Zero out every label whose element count is below the size threshold.
        component_ids, counts = np.unique(components, return_counts=True)
        too_small = component_ids[counts < min_size]
        components[np.isin(components, too_small)] = 0

        # Only the relabeled array is needed, not the forward/inverse maps.
        components = relabel_sequential(components)[0]

        gt_path = corrected_path.replace("_corrected", "_gt")
        imageio.imwrite(gt_path, components)


def run_grid_search():
    """Run the instance segmentation grid search over GRID_SEARCH_VALUES.

    Images ("*lut3*.tif") and ground-truth labels ("*_gt.tif") are collected
    recursively below ROOT and must match in number. The results are written
    to the "ihc-v4-gs" folder and the best parameters and score are printed.
    """
    result_dir = "ihc-v4-gs"
    block_shape, halo = (96, 256, 256), (8, 64, 64)

    def _collect(pattern):
        # Sorted so that images and labels are paired up by position.
        return sorted(glob(os.path.join(ROOT, pattern), recursive=True))

    image_paths = _collect("**/*lut3*.tif")
    label_paths = _collect("**/*_gt.tif")
    assert len(image_paths) == len(label_paths), f"{len(image_paths)}, {len(label_paths)}"

    segmenter = DistanceBasedInstanceSegmentation(
        load_model(MODEL_PATH), block_shape=block_shape, halo=halo
    )
    best_kwargs, best_score = instance_segmentation_grid_search(
        segmenter, image_paths, label_paths, result_dir, grid_search_values=GRID_SEARCH_VALUES
    )

    print("Grid-search result:")
    print(best_kwargs)
    print(best_score)


def evaluate_grid_search():
    """Summarize the grid-search results and plot pairwise parameter heatmaps.

    Reads every CSV file in the "ihc-v4-gs" folder, averages the "SA50"
    criterion over all rows that share the same parameter combination, and
    prints the best combination together with its score. Then shows one
    heatmap per selected parameter pair, with the criterion averaged over the
    remaining parameters.

    Raises:
        ValueError: If the result folder does not contain any CSV file.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    result_dir = "ihc-v4-gs"
    criterion = "SA50"

    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        # pd.concat would fail with an opaque "No objects to concatenate".
        raise ValueError(f"No grid-search result files (*.csv) found in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    grid_search_parameters = list(GRID_SEARCH_VALUES)

    # Keep only the parameter columns and the criterion, then average the
    # criterion over all rows with the same parameter combination.
    gs_result = gs_result[grid_search_parameters + [criterion]]
    gs_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Look the best row up by label and read the parameters by column name,
    # so the result does not depend on the column order.
    best_row = gs_result.loc[gs_result[criterion].idxmax()]
    best_score = best_row[criterion]
    best_kwargs = {name: best_row[name] for name in grid_search_parameters}

    print("Best parameters:")
    print(best_kwargs)
    print("With score:", best_score)

    # (row parameter, column parameter) for each heatmap.
    parameter_pairs = [
        ("center_distance_threshold", "boundary_distance_threshold"),
        ("distance_smoothing", "boundary_distance_threshold"),
        ("distance_smoothing", "center_distance_threshold"),
    ]
    _, axes = plt.subplots(len(parameter_pairs))
    for ax, (row_param, col_param) in zip(axes, parameter_pairs):
        res = gs_result.groupby([row_param, col_param])[criterion].mean().reset_index()
        res = res.pivot(index=row_param, columns=col_param, values=criterion)
        sns.heatmap(res, cmap="viridis", annot=True, fmt=".2g", cbar_kws={"label": criterion}, ax=ax)
        # The pivot columns run along x and the pivot index along y.
        ax.set_xlabel(col_param)
        ax.set_ylabel(row_param)
    plt.show()


def main():
    """Run the currently enabled stage of the grid-search workflow."""
    # Earlier stages are disabled; uncomment them to rebuild the ground-truth
    # labels or to rerun the grid search before evaluating.
    # preprocess_gt()
    # run_grid_search()
    evaluate_grid_search()


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion scripts/validation/IHCs/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main():
parser.add_argument("-i", "--input", default=ROOT)
parser.add_argument("--folders", default=ANNOTATION_FOLDERS)
parser.add_argument("--result_file", default="results.csv")
parser.add_argument("--segmentation_name", default="IHC_v2")
parser.add_argument("--segmentation_name", default="IHC_v4")
parser.add_argument("--cache_folder")
args = parser.parse_args()
run_evaluation(args.input, args.folders, args.result_file, args.cache_folder, args.segmentation_name)
Expand Down
3 changes: 2 additions & 1 deletion test/test_segmentation/test_unet_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_run_unet_prediction_tif_mask(self):

def test_run_unet_prediction_complex_watershed(self):
self._test_run_unet_prediction(
use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5,
use_tif=False, use_mask=True,
center_distance_threshold=0.5, boundary_distance_threshold=0.5, distance_smoothing=1.0,
)


Expand Down
Loading