Skip to content

Commit 9984b6c

Browse files
Merge pull request #46 from computational-cell-analytics/grid-search
Implement grid-search for IHC segmentation
2 parents 2a3160b + 4c1c8a4 commit 9984b6c

File tree

4 files changed

+134
-9
lines changed

4 files changed

+134
-9
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import warnings
1010
from concurrent import futures
11+
from functools import partial
1112
from typing import Optional, Tuple
1213

1314
import elf.parallel as parallel
@@ -17,10 +18,11 @@
1718
import torch
1819
import z5py
1920

20-
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
21+
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper, SimpleTransformationWrapperWithHalo
2122
from elf.wrapper.base import MultiTransformationWrapper
2223
from elf.wrapper.resized_volume import ResizedVolume
2324
from elf.io import open_file
25+
from skimage.filters import gaussian
2426
from torch_em.util import load_model
2527
from torch_em.util.prediction import predict_with_halo
2628
from tqdm import tqdm
@@ -278,6 +280,7 @@ def distance_watershed_implementation(
278280
center_distance_threshold: float = 0.4,
279281
boundary_distance_threshold: Optional[float] = None,
280282
fg_threshold: float = 0.5,
283+
distance_smoothing: float = 0.0,
281284
original_shape: Optional[Tuple[int, int, int]] = None
282285
) -> None:
283286
"""Parallel implementation of the distance-prediction based watershed.
@@ -290,6 +293,8 @@ def distance_watershed_implementation(
290293
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
291294
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
292295
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
296+
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
297+
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
293298
original_shape: The original shape to resize the segmentation to.
294299
"""
295300
if isinstance(input_path, str):
@@ -307,11 +312,14 @@ def distance_watershed_implementation(
307312
center_distances = SelectChannel(input_, 1)
308313
boundary_distances = SelectChannel(input_, 2)
309314

310-
# Apply (lazy) smoothing to both.
311-
# NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now.
312-
# smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
313-
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
314-
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
315+
# Apply (lazy) smoothing to both channels if distance smoothing was set.
316+
if distance_smoothing > 0:
317+
smooth = partial(gaussian, sigma=distance_smoothing)
318+
# We assume that the gaussian is truncated at 5.3 sigma (tolerance of 1e-6)
319+
halo = int(np.ceil(5.3 * distance_smoothing))
320+
halo = 3 * (halo,)
321+
center_distances = SimpleTransformationWrapperWithHalo(center_distances, transformation=smooth, halo=halo)
322+
boundary_distances = SimpleTransformationWrapperWithHalo(boundary_distances, transformation=smooth, halo=halo)
315323

316324
# Allocate the (zarr) array for the seeds.
317325
if output_folder is None:
@@ -427,6 +435,7 @@ def run_unet_prediction(
427435
center_distance_threshold: float = 0.4,
428436
boundary_distance_threshold: Optional[float] = None,
429437
fg_threshold: float = 0.5,
438+
distance_smoothing: float = 0.0,
430439
seg_class: Optional[str] = None,
431440
) -> None:
432441
"""Run prediction and segmentation with a distance U-Net.
@@ -446,6 +455,8 @@ def run_unet_prediction(
446455
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
447456
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
448457
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
458+
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
459+
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
449460
seg_class: Specifier for exclusion criterias for mask generation.
450461
"""
451462
if output_folder is not None:
@@ -470,7 +481,8 @@ def run_unet_prediction(
470481
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
471482
center_distance_threshold=center_distance_threshold,
472483
boundary_distance_threshold=boundary_distance_threshold,
473-
fg_threshold=fg_threshold
484+
fg_threshold=fg_threshold,
485+
distance_smoothing=distance_smoothing,
474486
)
475487

476488
return segmentation
@@ -590,6 +602,7 @@ def run_unet_segmentation_slurm(
590602
center_distance_threshold: float = 0.4,
591603
boundary_distance_threshold: float = 0.5,
592604
fg_threshold: float = 0.5,
605+
distance_smoothing: float = 0.0,
593606
) -> None:
594607
"""Create segmentation from prediction.
595608
@@ -600,10 +613,13 @@ def run_unet_segmentation_slurm(
600613
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
601614
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
602615
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
616+
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
617+
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
603618
"""
604619
min_size = int(min_size)
605620
pmap_out = os.path.join(output_folder, "predictions.zarr")
606621
distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
607622
boundary_distance_threshold=boundary_distance_threshold,
608623
fg_threshold=fg_threshold,
624+
distance_smoothing=distance_smoothing,
609625
min_size=min_size)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
from glob import glob
3+
4+
import imageio.v3 as imageio
5+
import numpy as np
6+
from skimage.measure import label
7+
from skimage.segmentation import relabel_sequential
8+
9+
from torch_em.util import load_model
10+
from torch_em.util.grid_search import DistanceBasedInstanceSegmentation, instance_segmentation_grid_search
11+
12+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-07-for-grid-search"
13+
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14" # noqa
14+
15+
GRID_SEARCH_VALUES = {
16+
"center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
17+
"boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
18+
"distance_smoothing": [0.0, 0.6, 1.0, 1.6],
19+
"min_size": [0],
20+
}
21+
22+
23+
def preprocess_gt():
24+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_corrected.tif"), recursive=True))
25+
for label_path in label_paths:
26+
seg = imageio.imread(label_path)
27+
seg = label(seg)
28+
29+
min_size = 750
30+
ids, sizes = np.unique(seg, return_counts=True)
31+
filter_ids = ids[sizes < min_size]
32+
seg[np.isin(seg, filter_ids)] = 0
33+
seg, _, _ = relabel_sequential(seg)
34+
35+
out_path = label_path.replace("_corrected", "_gt")
36+
imageio.imwrite(out_path, seg)
37+
38+
39+
def run_grid_search():
40+
image_paths = sorted(glob(os.path.join(ROOT, "**/*lut3*.tif"), recursive=True))
41+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_gt.tif"), recursive=True))
42+
assert len(image_paths) == len(label_paths), f"{len(image_paths)}, {len(label_paths)}"
43+
result_dir = "ihc-v4-gs"
44+
45+
block_shape = (96, 256, 256)
46+
halo = (8, 64, 64)
47+
model = load_model(MODEL_PATH)
48+
segmenter = DistanceBasedInstanceSegmentation(model, block_shape=block_shape, halo=halo)
49+
50+
best_kwargs, best_score = instance_segmentation_grid_search(
51+
segmenter, image_paths, label_paths, result_dir, grid_search_values=GRID_SEARCH_VALUES
52+
)
53+
print("Grid-search result:")
54+
print(best_kwargs)
55+
print(best_score)
56+
57+
58+
# TODO plot the grid search results
59+
def evaluate_grid_search():
60+
import matplotlib.pyplot as plt
61+
import pandas as pd
62+
import seaborn as sns
63+
64+
result_dir = "ihc-v4-gs"
65+
criterion = "SA50"
66+
67+
gs_files = glob(os.path.join(result_dir, "*.csv"))
68+
gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])
69+
70+
grid_search_parameters = list(GRID_SEARCH_VALUES.keys())
71+
72+
# Retrieve only the relevant columns and group by the gridsearch columns
73+
# and compute the mean value
74+
gs_result = gs_result[grid_search_parameters + [criterion]].reset_index()
75+
gs_result = gs_result.groupby(grid_search_parameters).mean().reset_index()
76+
77+
# Find the best score and best result
78+
best_score, best_idx = gs_result[criterion].max(), gs_result[criterion].idxmax()
79+
best_params = gs_result.iloc[best_idx]
80+
best_kwargs = {k: v for k, v in zip(grid_search_parameters, best_params)}
81+
82+
print("Best parameters:")
83+
print(best_kwargs)
84+
print("With score:", best_score)
85+
86+
fig, axes = plt.subplots(3)
87+
for i, (idx, col) in enumerate([
88+
("center_distance_threshold", "boundary_distance_threshold"),
89+
("distance_smoothing", "boundary_distance_threshold"),
90+
("distance_smoothing", "center_distance_threshold"),
91+
]):
92+
res = gs_result.groupby([idx, col]).mean().reset_index()
93+
res = res[[idx, col, criterion]]
94+
res = res.pivot(index=idx, columns=col, values=criterion)
95+
sns.heatmap(res, cmap="viridis", annot=True, fmt=".2g", cbar_kws={"label": criterion}, ax=axes[i])
96+
axes[i].set_xlabel(idx)
97+
axes[i].set_xlabel(col)
98+
plt.show()
99+
100+
101+
def main():
102+
# preprocess_gt()
103+
# run_grid_search()
104+
evaluate_grid_search()
105+
106+
107+
if __name__ == "__main__":
108+
main()

scripts/validation/IHCs/run_evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main():
5959
parser.add_argument("-i", "--input", default=ROOT)
6060
parser.add_argument("--folders", default=ANNOTATION_FOLDERS)
6161
parser.add_argument("--result_file", default="results.csv")
62-
parser.add_argument("--segmentation_name", default="IHC_v2")
62+
parser.add_argument("--segmentation_name", default="IHC_v4")
6363
parser.add_argument("--cache_folder")
6464
args = parser.parse_args()
6565
run_evaluation(args.input, args.folders, args.result_file, args.cache_folder, args.segmentation_name)

test/test_segmentation/test_unet_prediction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def test_run_unet_prediction_tif_mask(self):
6767

6868
def test_run_unet_prediction_complex_watershed(self):
6969
self._test_run_unet_prediction(
70-
use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5,
70+
use_tif=False, use_mask=True,
71+
center_distance_threshold=0.5, boundary_distance_threshold=0.5, distance_smoothing=1.0,
7172
)
7273

7374

0 commit comments

Comments
 (0)